// Package commands provides CLI commands for the admin tool
package commands
import (
"context"
"database/sql"
"os"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/spf13/cobra"
)
// DatabaseCommands builds the "db" parent command and registers its two
// subcommands, "stats" and "cleanup", wiring each to the supplied
// dependencies. The parent command itself has no run function.
func DatabaseCommands(userService *services.UserService, logger *observability.Logger, db *sql.DB) *cobra.Command {
	root := new(cobra.Command)
	root.Use = "db"
	root.Short = "Database management commands"
	root.Long = `Database management commands for the quiz application.
Available commands:
stats - Show database statistics
cleanup - Run database cleanup operations`

	// Register the children in the same order they are listed in the help text.
	root.AddCommand(
		statsCmd(userService, logger, db),
		cleanupCmd(logger, db),
	)
	return root
}
// statsCmd builds the "stats" subcommand. Its handler is produced by
// runStats from the same dependencies.
func statsCmd(userService *services.UserService, logger *observability.Logger, db *sql.DB) *cobra.Command {
	cmd := new(cobra.Command)
	cmd.Use = "stats"
	cmd.Short = "Show database statistics"
	cmd.Long = `Show database statistics including user counts and other metrics.`
	cmd.RunE = runStats(userService, logger, db)
	return cmd
}
// cleanupCmd builds the "cleanup" subcommand together with its --stats flag.
// The flag's storage is shared with the handler returned by runCleanup, so
// the handler sees the parsed value when it runs.
func cleanupCmd(logger *observability.Logger, db *sql.DB) *cobra.Command {
	statsOnly := new(bool)

	cmd := new(cobra.Command)
	cmd.Use = "cleanup"
	cmd.Short = "Run database cleanup operations"
	cmd.Long = `Run database cleanup operations to remove old data.
This command will:
- Remove questions with legacy question types
- Remove orphaned user responses
Use --stats flag to see what would be cleaned up without actually performing the cleanup.`
	cmd.RunE = runCleanup(logger, statsOnly, db)

	// --stats switches the command into report-only mode.
	cmd.Flags().BoolVar(statsOnly, "stats", false, "Only show cleanup statistics, don't perform cleanup")
	return cmd
}
// runStats builds the handler for "db stats". The handler logs diagnostic
// details (config file path and database description), loads all users via
// the user service, and logs the resulting user count. A failure to load
// users is logged and returned wrapped as an internal error.
func runStats(userService *services.UserService, logger *observability.Logger, db *sql.DB) func(cmd *cobra.Command, args []string) error {
	handler := func(_ *cobra.Command, _ []string) error {
		ctx := context.Background()

		// Diagnostics first, so they are logged even if the lookup below fails.
		diagnostics := map[string]interface{}{
			"config_file": os.Getenv("QUIZ_CONFIG_FILE"),
			"database":    getDatabaseInfo(db),
		}
		logger.Info(ctx, "Diagnostic info", diagnostics)
		logger.Info(ctx, "Showing database statistics", map[string]interface{}{})

		allUsers, fetchErr := userService.GetAllUsers(ctx)
		if fetchErr != nil {
			logger.Error(ctx, "Failed to get user statistics", fetchErr, map[string]interface{}{})
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to get user statistics: %v", fetchErr)
		}

		summary := map[string]interface{}{
			"total_users": len(allUsers),
			"database":    "PostgreSQL",
			"status":      "Connected",
		}
		logger.Info(ctx, "Database statistics", summary)
		return nil
	}
	return handler
}
// runCleanup builds the handler for "db cleanup". The handler logs
// diagnostics, refuses to continue without a database handle, and then either
// reports what a cleanup would remove (when *statsOnly is true) or performs
// the full cleanup through the cleanup service.
//
// statsOnly is dereferenced when the handler runs, not when runCleanup is
// called, so it reflects the parsed --stats flag.
func runCleanup(logger *observability.Logger, statsOnly *bool, db *sql.DB) func(cmd *cobra.Command, args []string) error {
	return func(_ *cobra.Command, _ []string) error {
		ctx := context.Background()
		reportOnly := *statsOnly

		logger.Info(ctx, "Diagnostic info", map[string]interface{}{
			"config_file": os.Getenv("QUIZ_CONFIG_FILE"),
			"database":    getDatabaseInfo(db),
		})
		logger.Info(ctx, "Running database cleanup", map[string]interface{}{"stats_only": reportOnly})

		// The connection is supplied by the caller; without it nothing can run.
		if db == nil {
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "database connection not available")
		}

		svc := services.NewCleanupServiceWithLogger(db, logger)

		if !reportOnly {
			// Full cleanup path.
			logger.Info(ctx, "Starting database cleanup", map[string]interface{}{"service": "cleanup"})
			runErr := svc.RunFullCleanup(ctx)
			if runErr != nil {
				logger.Error(ctx, "Cleanup failed", runErr, map[string]interface{}{"service": "cleanup"})
				return contextutils.WrapErrorf(contextutils.ErrInternalError, "cleanup failed: %v", runErr)
			}
			logger.Info(ctx, "Database cleanup completed successfully", map[string]interface{}{"service": "cleanup"})
			return nil
		}

		// Report-only path: fetch the counts and log them without deleting anything.
		counts, statsErr := svc.GetCleanupStats(ctx)
		if statsErr != nil {
			logger.Error(ctx, "Failed to get cleanup stats", statsErr, map[string]interface{}{"stats_only": true})
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to get cleanup stats: %v", statsErr)
		}

		legacy, orphaned := counts["legacy_questions"], counts["orphaned_responses"]
		logger.Info(ctx, "Database cleanup statistics", map[string]interface{}{"legacy_questions": legacy, "orphaned_responses": orphaned})

		total := legacy + orphaned
		summary := "Cleanup would remove items"
		if total == 0 {
			summary = "No cleanup needed - database is clean"
		}
		logger.Info(ctx, summary, map[string]interface{}{"total": total})
		return nil
	}
}
// Package commands provides CLI commands for the admin tool
package commands
import (
"context"
"database/sql"
"fmt"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/spf13/cobra"
)
// TranslationCommands builds the "translation" parent command and registers
// its single "cleanup" subcommand. The parent command has no run function.
func TranslationCommands(logger *observability.Logger, db *sql.DB) *cobra.Command {
	root := new(cobra.Command)
	root.Use = "translation"
	root.Short = "Translation cache management commands"
	root.Long = `Translation cache management commands for the quiz application.
Available commands:
cleanup - Remove expired translation cache entries`

	cleanup := translationCleanupCmd(logger, db)
	root.AddCommand(cleanup)
	return root
}
// translationCleanupCmd builds the "translation cleanup" subcommand together
// with its --dry-run flag. The flag's storage is shared with the handler
// returned by runTranslationCleanup.
func translationCleanupCmd(logger *observability.Logger, db *sql.DB) *cobra.Command {
	dryRun := new(bool)

	cmd := new(cobra.Command)
	cmd.Use = "cleanup"
	cmd.Short = "Remove expired translation cache entries"
	cmd.Long = `Remove expired translation cache entries from the database.
This command will:
- Delete all translation cache entries that have expired (older than 30 days)
- Report the number of entries deleted
Use --dry-run flag to see what would be cleaned up without actually performing the cleanup.`
	cmd.RunE = runTranslationCleanup(logger, dryRun, db)

	cmd.Flags().BoolVar(dryRun, "dry-run", false, "Show what would be cleaned up without actually performing the cleanup")
	return cmd
}
// runTranslationCleanup builds the handler for "translation cleanup".
//
// With *dryRun set, the handler counts rows in translation_cache whose
// expires_at is in the past and prints that count without deleting anything.
// Otherwise it asks the translation cache repository to delete the expired
// entries, prints the number removed, and logs the outcome.
//
// The handler returns an internal error when db is nil (the dry-run path
// would otherwise dereference it), and a wrapped error when the count query
// or the cleanup fails.
func runTranslationCleanup(logger *observability.Logger, dryRun *bool, db *sql.DB) func(*cobra.Command, []string) error {
	return func(_ *cobra.Command, _ []string) error {
		ctx := context.Background()

		// Guard against a missing connection, as the database cleanup command does.
		if db == nil {
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "database connection not available")
		}

		cacheRepo := services.NewTranslationCacheRepository(db, logger)

		if *dryRun {
			// Count expired entries without deleting
			var count int64
			err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM translation_cache WHERE expires_at < NOW()").Scan(&count)
			if err != nil {
				logger.Error(ctx, "Failed to count expired translation cache entries", err)
				return contextutils.WrapError(err, "failed to count expired entries")
			}
			fmt.Printf("Dry run: Would delete %d expired translation cache entries\n", count)
			return nil
		}

		// Perform actual cleanup
		count, err := cacheRepo.CleanupExpiredTranslations(ctx)
		if err != nil {
			logger.Error(ctx, "Failed to cleanup expired translation cache entries", err)
			return contextutils.WrapError(err, "failed to cleanup expired entries")
		}

		fmt.Printf("Successfully deleted %d expired translation cache entries\n", count)
		logger.Info(ctx, "Translation cache cleanup completed", map[string]interface{}{
			"deleted_count": count,
		})
		return nil
	}
}
package commands
import (
"context"
"fmt"
"os"
"syscall"
"golang.org/x/term"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/spf13/cobra"
)
// UserCommands builds the "user" parent command and registers its two
// subcommands, "list" and "reset-password". The parent command has no run
// function.
func UserCommands(userService *services.UserService, logger *observability.Logger, databaseURL string) *cobra.Command {
	root := new(cobra.Command)
	root.Use = "user"
	root.Short = "User management commands"
	root.Long = `User management commands for the quiz application.
Available commands:
list - List all users
reset-password - Reset password for a specific user`

	root.AddCommand(
		listCmd(userService, logger, databaseURL),
		resetPasswordCmd(userService, logger),
	)
	return root
}
// listCmd builds the "list" subcommand. Its handler is produced by
// runListUsers from the same dependencies.
func listCmd(userService *services.UserService, logger *observability.Logger, databaseURL string) *cobra.Command {
	cmd := new(cobra.Command)
	cmd.Use = "list"
	cmd.Short = "List all users"
	cmd.Long = `List all users in the database with their basic information.`
	cmd.RunE = runListUsers(userService, logger, databaseURL)
	return cmd
}
// resetPasswordCmd builds the "reset-password" subcommand. The username is an
// optional positional argument; the handler from runResetPassword prompts for
// it when it is absent.
func resetPasswordCmd(userService *services.UserService, logger *observability.Logger) *cobra.Command {
	cmd := new(cobra.Command)
	cmd.Use = "reset-password [username]"
	cmd.Short = "Reset password for a user"
	cmd.Long = `Reset the password for a specific user. If username is not provided, you will be prompted for it.`
	cmd.RunE = runResetPassword(userService, logger)
	return cmd
}
// runListUsers builds the handler for "user list".
//
// The handler logs diagnostics (config file path and the masked database
// URL), loads all users through the user service, and prints them to stdout
// as a fixed-width table: a header row, a separator line of dashes, then one
// row per user. Optional fields that are not set are shown as "N/A", and the
// AI flag is shown as "Yes" only when it is both set and true.
//
// When there are no users the handler only logs that fact and prints nothing.
// A failure to load users is logged and returned wrapped.
func runListUsers(userService *services.UserService, logger *observability.Logger, databaseURL string) func(cmd *cobra.Command, args []string) error {
	// separatorWidth is the length of the dashed line printed under the header.
	const separatorWidth = 120

	// orNA returns value when valid is true and the "N/A" placeholder otherwise.
	orNA := func(valid bool, value string) string {
		if valid {
			return value
		}
		return "N/A"
	}

	return func(_ *cobra.Command, _ []string) error {
		ctx := context.Background()

		// Show diagnostic information
		logger.Info(ctx, "Admin command diagnostics", map[string]interface{}{"config_file": os.Getenv("QUIZ_CONFIG_FILE"), "database_url": maskDatabaseURL(databaseURL)})
		logger.Info(ctx, "Listing all users", map[string]interface{}{})

		users, err := userService.GetAllUsers(ctx)
		if err != nil {
			logger.Error(ctx, "Failed to get users", err, map[string]interface{}{})
			return contextutils.WrapError(err, "failed to get users")
		}
		if len(users) == 0 {
			logger.Info(ctx, "No users found in the database", nil)
			return nil
		}

		// Print header to stdout (user-facing table)
		fmt.Printf("%-5s %-20s %-30s %-15s %-10s %-10s %-10s\n", "ID", "Username", "Email", "Language", "Level", "AI Enabled", "Created")

		// A freshly made byte slice is all zero bytes, so it has to be filled
		// with '-' explicitly to print a visible separator.
		separator := make([]byte, separatorWidth)
		for i := range separator {
			separator[i] = '-'
		}
		fmt.Println(string(separator))

		// Print each user
		for _, user := range users {
			aiEnabled := "No"
			if user.AIEnabled.Valid && user.AIEnabled.Bool {
				aiEnabled = "Yes"
			}
			fmt.Printf("%-5d %-20s %-30s %-15s %-10s %-10s %-10s\n",
				user.ID,
				user.Username,
				orNA(user.Email.Valid, user.Email.String),
				orNA(user.PreferredLanguage.Valid, user.PreferredLanguage.String),
				orNA(user.CurrentLevel.Valid, user.CurrentLevel.String),
				aiEnabled,
				user.CreatedAt.Format("2006-01-02"),
			)
		}

		logger.Info(ctx, "Listed users", map[string]interface{}{"total": len(users)})
		return nil
	}
}
// runResetPassword builds the handler for "user reset-password".
//
// The handler takes the username from the first positional argument or, when
// none is given, reads it from stdin. It then prompts twice for the new
// password on the terminal, rejects an empty password or a mismatch between
// the two entries, looks the user up by username, and stores the new password
// through the user service. Success is printed to stdout and logged.
//
// Errors are returned when input cannot be read, the username or password is
// empty, the confirmation differs, the user lookup fails or finds nobody, or
// the password update fails.
func runResetPassword(userService *services.UserService, logger *observability.Logger) func(cmd *cobra.Command, args []string) error {
	// readSecret prints prompt, reads one password entry from the terminal on
	// stdin, and on success prints the newline that follows the entry.
	readSecret := func(prompt string) (string, error) {
		fmt.Print(prompt)
		raw, readErr := term.ReadPassword(int(syscall.Stdin))
		if readErr != nil {
			return "", readErr
		}
		fmt.Println() // New line after password input
		return string(raw), nil
	}

	return func(_ *cobra.Command, args []string) error {
		ctx := context.Background()

		// Username: positional argument when present, interactive prompt otherwise.
		var username string
		switch len(args) {
		case 0:
			fmt.Print("Enter username: ")
			if _, scanErr := fmt.Scanln(&username); scanErr != nil {
				return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to read username: %v", scanErr)
			}
		default:
			username = args[0]
		}
		if username == "" {
			return contextutils.ErrorWithContextf("username is required")
		}

		// First entry of the new password.
		newPassword, pwErr := readSecret("Enter new password: ")
		if pwErr != nil {
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to read password: %v", pwErr)
		}
		if newPassword == "" {
			return contextutils.ErrorWithContextf("password cannot be empty")
		}

		// Second entry, which must match the first.
		confirmation, confirmErr := readSecret("Confirm new password: ")
		if confirmErr != nil {
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to read password confirmation: %v", confirmErr)
		}
		if confirmation != newPassword {
			return contextutils.ErrorWithContextf("passwords do not match")
		}

		logger.Info(ctx, "Resetting password for user", map[string]interface{}{
			"username": username,
		})

		// Resolve the account; a nil user without an error means no such username.
		target, lookupErr := userService.GetUserByUsername(ctx, username)
		if lookupErr != nil {
			logger.Error(ctx, "Failed to get user", lookupErr, map[string]interface{}{"username": username})
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to get user '%s': %v", username, lookupErr)
		}
		if target == nil {
			logger.Error(ctx, "User not found", nil, map[string]interface{}{"username": username})
			return contextutils.ErrorWithContextf("user '%s' not found", username)
		}

		// Store the new password.
		if updateErr := userService.UpdateUserPassword(ctx, target.ID, newPassword); updateErr != nil {
			logger.Error(ctx, "Failed to update password", updateErr, map[string]interface{}{
				"username": username,
				"user_id":  target.ID,
			})
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to update password for user '%s': %v", username, updateErr)
		}

		fmt.Printf("â Password successfully reset for user '%s' (ID: %d)\n", username, target.ID)
		logger.Info(ctx, "Password reset successful", map[string]interface{}{
			"username": username,
			"user_id":  target.ID,
		})
		return nil
	}
}
package commands
import (
"database/sql"
"fmt"
"strings"
)
// maskDatabaseURL hides the credential part of a database URL for display.
//
// Everything up to and including the last "@" is replaced with
// "postgres://***:***@", so only the host, port, path and query remain
// visible. Cutting at the last "@" (rather than requiring exactly one) means
// a password that itself contains "@" is still masked instead of being
// returned verbatim. A URL without "@" carries no credentials in this form
// and is returned unchanged.
//
// Note that the scheme of the result is always "postgres://", whatever the
// scheme of the input was.
func maskDatabaseURL(url string) string {
	at := strings.LastIndex(url, "@")
	if at < 0 {
		return url
	}
	return "postgres://***:***@" + url[at+1:]
}
// getDatabaseInfo describes the database connection as a short,
// human-readable string.
//
// It returns "Not connected" for a nil handle. Otherwise it queries the
// current database name and then the server address, degrading gracefully:
// if the name query fails it reports an unknown database, and if only the
// address query fails it reports the name alone.
func getDatabaseInfo(db *sql.DB) string {
	if db == nil {
		return "Not connected"
	}

	// Database name; without it nothing more specific can be reported.
	var name string
	if nameErr := db.QueryRow("SELECT current_database()").Scan(&name); nameErr != nil {
		return "Connected (unknown database)"
	}

	// Server address is optional extra detail.
	var addr string
	if addrErr := db.QueryRow("SELECT inet_server_addr()::text").Scan(&addr); addrErr != nil {
		return fmt.Sprintf("Connected to %s", name)
	}

	return fmt.Sprintf("Connected to %s on %s", name, addr)
}
// Package main provides the main entry point for the quiz application admin CLI tool.
package main
import (
"context"
"fmt"
"os"
"quizapp/cmd/adm/commands"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/observability"
"quizapp/internal/services"
"github.com/spf13/cobra"
)
// Global variables for shared resources.
// They are populated during startup in this file and then handed to the
// command constructors.
var (
// cfg holds the configuration returned by config.NewConfig.
cfg *config.Config
// logger is the logger instance returned by observability.SetupObservability.
logger *observability.Logger
// userService is the user service built on the opened database connection.
userService *services.UserService
)
// main runs the admin CLI and exits with the status code reported by
// runAdmin.
func main() {
	os.Exit(runAdmin())
}

// runAdmin performs the whole life cycle of the admin CLI and returns the
// process exit code (0 on success, 1 on any failure).
//
// It picks a default QUIZ_CONFIG_FILE when none is set, loads the
// configuration, forces the log level to "error" and switches OpenTelemetry
// off, sets up observability, opens the database without running migrations,
// builds the user service, and executes the cobra command tree.
//
// The work lives here rather than in main so that failures are reported by
// returning instead of calling os.Exit: os.Exit terminates the process without
// running deferred functions, which would skip closing the database and
// shutting down the observability providers.
func runAdmin() int {
	ctx := context.Background()

	// Set default config file if not already set
	if os.Getenv("QUIZ_CONFIG_FILE") == "" {
		// Try to find the config file in common locations
		defaultPaths := []string{
			"../merged.config.yaml",    // From backend/cmd/adm/
			"../../merged.config.yaml", // From backend/cmd/adm/ (alternative)
			"merged.config.yaml",       // Current directory
		}
		for _, path := range defaultPaths {
			if _, err := os.Stat(path); err != nil {
				continue
			}
			// First existing candidate wins.
			if err := os.Setenv("QUIZ_CONFIG_FILE", path); err != nil {
				fmt.Fprintf(os.Stderr, "Failed to set QUIZ_CONFIG_FILE environment variable: %v\n", err)
				return 1
			}
			break
		}
	}

	// Load configuration
	var err error
	cfg, err = config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
		return 1
	}

	// Override log level for admin tool
	cfg.Server.LogLevel = "error"

	// Disable all OpenTelemetry features for admin CLI to avoid connection errors
	cfg.OpenTelemetry.EnableTracing = false
	cfg.OpenTelemetry.EnableMetrics = false
	cfg.OpenTelemetry.EnableLogging = false

	// Setup observability (tracing/metrics/logging)
	tp, mp, loggerInstance, err := observability.SetupObservability(&cfg.OpenTelemetry, "quiz-admin")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		return 1
	}

	// Store logger globally
	logger = loggerInstance

	// Shut the providers down on every return path.
	defer func() {
		if tp != nil {
			if err := tp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error(), "provider": "tracer"})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error(), "provider": "meter"})
			}
		}
	}()

	// Initialize database manager
	dbManager := database.NewManager(logger)

	// Initialize database connection with configuration (no migrations for admin tool)
	// NOTE(review): db_url is logged unmasked here and in the close handler
	// below; if the URL embeds credentials they end up in the logs — confirm
	// whether it should be masked.
	db, err := dbManager.InitDBWithoutMigrations(cfg.Database)
	if err != nil {
		logger.Error(ctx, "Failed to connect to database", err, map[string]interface{}{"db_url": cfg.Database.URL})
		return 1
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database connection", map[string]interface{}{"error": err.Error(), "db_url": cfg.Database.URL})
		}
	}()

	// Initialize services
	userService = services.NewUserServiceWithLogger(db, cfg, logger)

	// Create the root command
	rootCmd := &cobra.Command{
		Use:   "adm",
		Short: "Quiz Application Administration Tool",
		Long: `Quiz Application Administration Tool
A comprehensive CLI tool for administering the quiz application.
Provides commands for user management, database operations, and system administration.`,
		Run: func(cmd *cobra.Command, _ []string) {
			// Show help if no subcommand provided
			if err := cmd.Help(); err != nil {
				fmt.Printf("Error showing help: %v\n", err)
			}
		},
	}

	// Add subcommands with initialized services
	rootCmd.AddCommand(commands.UserCommands(userService, logger, cfg.Database.URL))
	rootCmd.AddCommand(commands.DatabaseCommands(userService, logger, db))
	rootCmd.AddCommand(commands.TranslationCommands(logger, db))

	// Execute the command
	if err := rootCmd.Execute(); err != nil {
		return 1
	}
	return 0
}
// Package main provides a CLI tool for running the worker to generate questions for a specific user.
package main
import (
"context"
"flag"
"fmt"
"os"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/worker"
)
func main() {
ctx := context.Background()
// Define command line flags
var (
username = flag.String("username", "", "Username to generate questions for (required)")
level = flag.String("level", "", "Override user's current level (optional)")
language = flag.String("language", "", "Override user's preferred language (optional)")
questionType = flag.String("type", "vocabulary", "Question type: vocabulary, fill_blank, qa, reading_comprehension")
topic = flag.String("topic", "", "Specific topic for questions (optional)")
count = flag.Int("count", 5, "Number of questions to generate")
aiProvider = flag.String("ai-provider", "", "Override AI provider (optional)")
aiModel = flag.String("ai-model", "", "Override AI model (optional)")
aiAPIKey = flag.String("ai-api-key", "", "Override AI API key (optional)")
help = flag.Bool("help", false, "Show help message")
)
flag.Parse()
if *help {
printUsage(nil)
return
}
if *username == "" {
fmt.Fprintln(os.Stderr, "Error: --username flag is required")
os.Exit(1)
}
// Load configuration
cfg, err := config.NewConfig()
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
os.Exit(1)
}
// Setup observability (tracing/metrics/logging)
tp, mp, logger, err := observability.SetupObservability(&cfg.OpenTelemetry, "quiz-cli-worker")
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
os.Exit(1)
}
defer func() {
if tp != nil {
if err := tp.Shutdown(context.TODO()); err != nil {
logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error()})
}
}
if mp != nil {
if err := mp.Shutdown(context.TODO()); err != nil {
logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error()})
}
}
}()
logger.Info(ctx, "Starting quiz CLI worker", map[string]interface{}{
"username": *username,
"question_type": *questionType,
"count": *count,
})
// Validate question type
validTypes := map[string]models.QuestionType{
"vocabulary": models.Vocabulary,
"fill_blank": models.FillInBlank,
"qa": models.QuestionAnswer,
"reading_comprehension": models.ReadingComprehension,
}
qType, valid := validTypes[strings.ToLower(*questionType)]
if !valid {
logger.Error(ctx, "Invalid question type", nil, map[string]interface{}{"question_type": *questionType})
fmt.Fprintf(os.Stderr, "Error: Invalid question type '%s'\n", *questionType)
os.Exit(1)
}
// Validate level if provided
if *level != "" {
if !isValidLevel(*level, cfg.GetAllLevels()) {
logger.Error(ctx, "Invalid level", nil, map[string]interface{}{"level": *level})
fmt.Fprintf(os.Stderr, "Error: Invalid level '%s'\n", *level)
os.Exit(1)
}
}
// Validate language if provided (use dynamic list from config)
validLanguages := cfg.GetLanguages()
if *language != "" {
if !isValidLanguage(*language, validLanguages) {
logger.Error(ctx, "Invalid language", nil, map[string]interface{}{"language": *language})
fmt.Fprintf(os.Stderr, "Error: Invalid language '%s'\n", *language)
os.Exit(1)
}
}
// Initialize database manager with logger
dbManager := database.NewManager(logger)
// Initialize database connection with configuration
db, err := dbManager.InitDBWithoutMigrations(cfg.Database)
if err != nil {
logger.Error(ctx, "Failed to connect to database", err, map[string]interface{}{"db_url": cfg.Database.URL})
fmt.Fprintf(os.Stderr, "Failed to connect to database: %v\n", err)
os.Exit(1)
}
defer func() {
if err := db.Close(); err != nil {
logger.Warn(ctx, "Warning: failed to close database connection", map[string]interface{}{"error": err.Error(), "db_url": cfg.Database.URL})
}
}()
// Initialize services
userService := services.NewUserServiceWithLogger(db, cfg, logger)
learningService := services.NewLearningServiceWithLogger(db, cfg, logger)
// Create question service
questionService := services.NewQuestionServiceWithLogger(db, learningService, cfg, logger)
// Create usage stats service
usageStatsService := services.NewUsageStatsService(cfg, db, logger)
aiService := services.NewAIService(cfg, logger, usageStatsService)
workerService := services.NewWorkerServiceWithLogger(db, logger)
// Get user by username
user, err := userService.GetUserByUsername(ctx, *username)
if err != nil {
logger.Error(ctx, "Failed to get user", err)
fmt.Fprintf(os.Stderr, "Failed to get user: %v\n", err)
os.Exit(1)
}
if user == nil {
logger.Error(ctx, "User not found", nil, map[string]interface{}{"username": *username})
fmt.Fprintf(os.Stderr, "User not found: %s\n", *username)
os.Exit(1)
return
}
logger.Info(ctx, "Found user", map[string]interface{}{"username": user.Username, "user_id": user.ID})
// Apply AI overrides if provided
if *aiProvider != "" {
user.AIProvider.String = *aiProvider
user.AIProvider.Valid = true
user.AIEnabled.Bool = true
user.AIEnabled.Valid = true
}
if *aiModel != "" {
user.AIModel.String = *aiModel
user.AIModel.Valid = true
}
if *aiAPIKey != "" {
// Set AI provider and API key if provided
if *aiProvider != "" && *aiAPIKey != "" {
if err := userService.SetUserAPIKey(ctx, user.ID, *aiProvider, *aiAPIKey); err != nil {
logger.Error(ctx, "Failed to set API key", err)
fmt.Fprintf(os.Stderr, "Failed to set API key: %v\n", err)
os.Exit(1)
}
} else if *aiAPIKey != "" {
// If only API key is provided, use the user's current AI provider
if err := userService.SetUserAPIKey(ctx, user.ID, user.AIProvider.String, *aiAPIKey); err != nil {
logger.Error(ctx, "Failed to set API key", err)
fmt.Fprintf(os.Stderr, "Failed to set API key: %v\n", err)
os.Exit(1)
}
}
}
// Check if user has AI enabled (after potential overrides)
if !user.AIEnabled.Valid || !user.AIEnabled.Bool {
logger.Warn(ctx, "User does not have AI enabled", map[string]interface{}{"username": user.Username, "user_id": user.ID})
logger.Info(ctx, "You may want to enable AI for this user first or use --ai-provider flag")
}
// Determine language and level to use
languageToUse := user.PreferredLanguage.String
if *language != "" {
languageToUse = *language
}
levelToUse := user.CurrentLevel.String
if *level != "" {
levelToUse = *level
}
// Validate that we have required settings
if languageToUse == "" {
logger.Error(ctx, "No language specified", nil, map[string]interface{}{"username": user.Username, "user_id": user.ID})
fmt.Fprintln(os.Stderr, "Error: No language specified. User has no preferred language and --language flag not provided")
os.Exit(1)
}
if levelToUse == "" {
logger.Error(ctx, "No level specified", nil, map[string]interface{}{"username": user.Username, "user_id": user.ID})
fmt.Fprintln(os.Stderr, "Error: No level specified. User has no current level and --level flag not provided")
os.Exit(1)
}
// Print configuration
fmt.Printf("=== CLI Worker Configuration ===\n")
fmt.Printf("User: %s (ID: %d)\n", user.Username, user.ID)
fmt.Printf("Language: %s\n", languageToUse)
fmt.Printf("Level: %s\n", levelToUse)
fmt.Printf("Question Type: %s\n", qType)
fmt.Printf("Count: %d\n", *count)
if *topic != "" {
fmt.Printf("Topic: %s\n", *topic)
}
if user.AIProvider.Valid && user.AIProvider.String != "" {
fmt.Printf("AI Provider: %s\n", user.AIProvider.String)
}
if user.AIModel.Valid && user.AIModel.String != "" {
fmt.Printf("AI Model: %s\n", user.AIModel.String)
}
fmt.Printf("===============================\n\n")
// Create email service
emailService := services.CreateEmailService(cfg, logger)
// Create daily question service
dailyQuestionService := services.NewDailyQuestionService(db, logger, questionService, learningService)
// Create story service
storyService := services.NewStoryService(db, cfg, logger)
// Create word of the day service
wordOfTheDayService := services.NewWordOfTheDayService(db, logger)
// Create translation cache repository
translationCacheRepo := services.NewTranslationCacheRepository(db, logger)
// Create a minimal worker instance for question generation
workerInstance := worker.NewWorker(userService, questionService, aiService, learningService, workerService, dailyQuestionService, wordOfTheDayService, storyService, emailService, nil, translationCacheRepo, "cli", cfg, logger)
// Create context with timeout
ctx, cancel := context.WithTimeout(ctx, config.CLIWorkerTimeout)
defer cancel()
// Log CLI worker start with structured logging
logger.Info(ctx, "CLI worker starting question generation", map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"question_type": qType,
"count": *count,
"language": languageToUse,
"level": levelToUse,
})
// Generate questions
fmt.Printf("Starting question generation...\n")
startTime := time.Now()
result, err := workerInstance.GenerateQuestionsForUser(ctx, user, languageToUse, levelToUse, qType, *count, *topic)
duration := time.Since(startTime)
if err != nil {
fmt.Printf("\nâ Question generation failed after %v\n", duration)
fmt.Printf("Error: %v\n", err)
os.Exit(1)
}
fmt.Printf("\nâ Question generation completed successfully in %v\n", duration)
fmt.Printf("Result: %s\n", result)
}
// isValidLevel reports whether level matches any entry of validLevels,
// ignoring letter case. It returns false for an empty or nil list.
func isValidLevel(level string, validLevels []string) bool {
	for i := 0; i < len(validLevels); i++ {
		if strings.EqualFold(validLevels[i], level) {
			return true
		}
	}
	return false
}
// isValidLanguage reports whether language matches any entry of
// validLanguages, ignoring letter case. It returns false for an empty or nil
// list.
func isValidLanguage(language string, validLanguages []string) bool {
	for i := 0; i < len(validLanguages); i++ {
		if strings.EqualFold(validLanguages[i], language) {
			return true
		}
	}
	return false
}
// printUsage prints the command-line help for the CLI worker to stdout.
//
// The flag list is generated from the flags registered on the default flag
// set rather than written out by hand, so names, descriptions and defaults
// cannot drift from the definitions. Flags must therefore be defined before
// this function is called.
//
// When cfg is non-nil the valid levels, languages and provider codes from the
// configuration are listed after the flags. When cfg is nil the flag list is
// still printed, followed by an error on stderr saying the configuration is
// missing, and the configuration-dependent lists are omitted.
func printUsage(cfg *config.Config) {
	fmt.Printf("Usage: cli-worker [flags]\n")
	fmt.Printf("Flags:\n")
	flag.VisitAll(func(f *flag.Flag) {
		// UnquoteUsage yields the value type name ("string", "int", or ""
		// for boolean flags) together with the usage text.
		typeName, usage := flag.UnquoteUsage(f)
		line := "  -" + f.Name
		if typeName != "" {
			line += " " + typeName
		}
		line += "\t" + usage
		// Only mention defaults that are not the zero value of the flag type.
		switch f.DefValue {
		case "", "0", "false":
		default:
			line += " (default " + f.DefValue + ")"
		}
		fmt.Println(line)
	})
	fmt.Println()

	if cfg == nil {
		fmt.Fprintf(os.Stderr, "Error: Configuration is missing or invalid.\n")
		return
	}

	fmt.Printf("Valid levels: %s\n", strings.Join(cfg.GetAllLevels(), ", "))
	fmt.Printf("Valid languages: %s\n", strings.Join(cfg.GetLanguages(), ", "))
	if cfg.Providers != nil {
		providerNames := make([]string, 0, len(cfg.Providers))
		for _, p := range cfg.Providers {
			providerNames = append(providerNames, p.Code)
		}
		fmt.Printf("Valid providers: %s\n", strings.Join(providerNames, ", "))
	} else {
		fmt.Printf("Valid providers: \n")
	}
}
// Package main provides a small CLI utility to reset the application's
// database to a clean state. It is intended for local development and
// testing only and will permanently delete all data when run.
package main
import (
"bufio"
"context"
"fmt"
"os"
"strings"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/observability"
"quizapp/internal/services"
)
// fatalIfErr logs msg, err and fields at error level and then terminates the
// process with exit status 1. Despite its name it does not inspect err: it
// always exits, so callers decide beforehand whether the situation is fatal.
func fatalIfErr(ctx context.Context, logger *observability.Logger, msg string, err error, fields map[string]interface{}) {
	const failureStatus = 1

	logger.Error(ctx, msg, err, fields)
	os.Exit(failureStatus)
}
// main is the entry point of the database reset utility.
//
// It loads the configuration, sets up observability, prints a warning banner
// and the masked database URL, and asks the operator to confirm. After
// confirmation it connects to the database, runs the migrations, makes sure
// the admin user from the configuration exists, and prints the admin
// credentials. Any failure after the logger exists is logged and ends the
// process with exit status 1.
//
// The database URL may embed credentials, so every log field that carries it
// uses the masked form, the same one shown to the operator on stdout.
func main() {
	ctx := context.Background()

	// Load configuration first
	cfg, err := config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
		os.Exit(1)
	}

	// Setup observability (tracing/metrics/logging)
	tp, mp, logger, err := observability.SetupObservability(&cfg.OpenTelemetry, "reset-db")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		os.Exit(1)
	}
	defer func() {
		if tp != nil {
			if err := tp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error(), "provider": "tracer"})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error(), "provider": "meter"})
			}
		}
	}()

	fmt.Println("âï DATABASE RESET UTILITY âï")
	fmt.Println("=============================")
	fmt.Println("This will PERMANENTLY DELETE ALL DATA in the database!")
	fmt.Println("This includes:")
	fmt.Println("- All users (including admin)")
	fmt.Println("- All questions")
	fmt.Println("- All user responses")
	fmt.Println("- All performance metrics")
	fmt.Println("")

	logger.Info(ctx, "Attempting to reset the database", map[string]interface{}{"service": "reset-db"})

	if cfg.Database.URL == "" {
		fatalIfErr(ctx, logger, "Database URL is empty", nil, map[string]interface{}{"error": "Database URL is empty. Cannot proceed with reset."})
	}

	// Masked once and reused for both the console output and the log fields.
	maskedURL := maskDatabaseURL(cfg.Database.URL)

	// Print database info
	fmt.Println("ð Database Information:")
	fmt.Printf("URL: %s\n", maskedURL)
	fmt.Println("")

	// Confirm with user
	if !confirmReset() {
		fmt.Println("Reset cancelled.")
		return
	}

	// Initialize database manager with logger
	dbManager := database.NewManager(logger)

	// Initialize database connection with configuration
	db, err := dbManager.InitDBWithConfig(cfg.Database)
	if err != nil {
		fatalIfErr(ctx, logger, "Failed to connect to database", err, map[string]interface{}{"db_url": maskedURL})
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database connection", map[string]interface{}{"error": err.Error(), "db_url": maskedURL})
		}
	}()

	// Initialize services
	userService := services.NewUserServiceWithLogger(db, cfg, logger)

	// Drop all tables
	fmt.Println("ðï Dropping all tables...")
	logger.Info(ctx, "Dropping all tables", map[string]interface{}{"db_url": maskedURL, "service": "reset-db"})
	// For now, we'll just run migrations which will recreate the schema
	// In a real implementation, you might want to add a DropAllTables method to the database manager

	// Run migrations
	fmt.Println("ð Running database migrations...")
	logger.Info(ctx, "Running database migrations", map[string]interface{}{"db_url": maskedURL, "service": "reset-db"})
	if err := dbManager.RunMigrations(db); err != nil {
		fatalIfErr(ctx, logger, "Failed to run migrations", err, map[string]interface{}{"db_url": maskedURL})
	}
	fmt.Println("â Database migrations completed successfully!")
	logger.Info(ctx, "Database migrations completed successfully", map[string]interface{}{"db_url": maskedURL, "service": "reset-db"})

	// Recreate admin user immediately
	fmt.Printf("Recreating admin user '%s'...\n", cfg.Server.AdminUsername)
	logger.Info(ctx, "Recreating admin user", map[string]interface{}{"username": cfg.Server.AdminUsername, "service": "reset-db"})

	// Ensure admin user exists
	if err := userService.EnsureAdminUserExists(ctx, cfg.Server.AdminUsername, cfg.Server.AdminPassword); err != nil {
		fatalIfErr(ctx, logger, "Failed to ensure admin user exists", err, map[string]interface{}{"admin_username": cfg.Server.AdminUsername})
	}
	fmt.Println("â Admin user recreated successfully!")
	logger.Info(ctx, "Admin user recreated successfully", map[string]interface{}{"username": cfg.Server.AdminUsername, "service": "reset-db"})
	fmt.Println("")

	// Print admin credentials
	fmt.Printf("\nAdmin user credentials:\n")
	fmt.Printf(" Username: %s\n", cfg.Server.AdminUsername)
	fmt.Printf(" Password: %s\n", cfg.Server.AdminPassword)
	fmt.Println("")
	fmt.Println("â Database is now ready to use!")
	fmt.Println("- You can now start the server or use the existing running instance")
	fmt.Println("- Use the credentials above to log into the application")
}
// confirmReset interactively asks the operator to confirm the destructive
// reset and reports whether they agreed.
//
// The prompt is repeated until the answer is "yes" (confirm) or "no"/empty
// (cancel); matching is case-insensitive and ignores surrounding whitespace.
// If standard input cannot be read (for example it is closed or reaches EOF
// in a non-interactive run) the reset is treated as cancelled: retrying a
// failed read would otherwise loop forever, and a destructive operation must
// never proceed without an explicit "yes".
func confirmReset() bool {
	reader := bufio.NewReader(os.Stdin)
	for {
		fmt.Print("Are you sure you want to reset the database? (type 'yes' to confirm): ")
		response, err := reader.ReadString('\n')
		if err != nil {
			// A read error is not recoverable by re-prompting; fail safe.
			fmt.Println("Error reading input:", err)
			return false
		}
		switch strings.TrimSpace(strings.ToLower(response)) {
		case "yes":
			return true
		case "no", "":
			return false
		default:
			fmt.Println("Please type 'yes' to confirm or 'no' to cancel.")
		}
	}
}
// maskDatabaseURL returns a display-safe version of a database URL.
//
// Everything up to and including the credentials is replaced with
// "postgres://***:***@"; the part after the credentials (host, port, path,
// query) is kept. The split is made at the LAST "@" so that a password that
// itself contains an unescaped "@" is still hidden rather than causing the
// whole URL to be printed verbatim. A URL without "@" carries no credentials
// and is returned unchanged.
func maskDatabaseURL(url string) string {
	at := strings.LastIndex(url, "@")
	if at < 0 {
		return url
	}
	return "postgres://***:***@" + url[at+1:]
}
// Package main provides the main entry point for the quiz application backend server.
// It sets up the HTTP server, database connections, middleware, and API routes.
package main
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"quizapp/internal/config"
"quizapp/internal/di"
"quizapp/internal/handlers"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
)
// Application encapsulates the main application logic and can be tested.
// It pairs the dependency-injection container with the HTTP router that
// NewApplication builds from the container's services.
type Application struct {
// container is the service container; Shutdown delegates to it.
container di.ServiceContainerInterface
// router is the gin engine produced by handlers.NewRouter and served by Run.
router *gin.Engine
}
// NewApplication wires an Application from the given service container.
//
// Every service the HTTP layer needs is fetched from the container up front;
// the first lookup that fails aborts construction and its error is returned
// wrapped with the name of the missing service. The resolved services,
// together with the container's config and logger, are then handed to
// handlers.NewRouter to build the router.
func NewApplication(container di.ServiceContainerInterface) (*Application, error) {
	users, err := container.GetUserService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get user service")
	}

	questions, err := container.GetQuestionService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get question service")
	}

	learning, err := container.GetLearningService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get learning service")
	}

	ai, err := container.GetAIService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get AI service")
	}

	workers, err := container.GetWorkerService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get worker service")
	}

	dailyQuestions, err := container.GetDailyQuestionService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get daily question service")
	}

	stories, err := container.GetStoryService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get story service")
	}

	oauth, err := container.GetOAuthService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get OAuth service")
	}

	hints, err := container.GetGenerationHintService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get generation hint service")
	}

	conversations, err := container.GetConversationService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get conversation service")
	}

	translations, err := container.GetTranslationService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get translation service")
	}

	snippets, err := container.GetSnippetsService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get snippets service")
	}

	usageStats, err := container.GetUsageStatsService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get usage stats service")
	}

	wordOfTheDay, err := container.GetWordOfTheDayService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get word of the day service")
	}

	apiKeys, err := container.GetAuthAPIKeyService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get auth API key service")
	}

	// Argument order follows the router factory's signature; note that the
	// conversation service is passed before the OAuth service.
	appConfig := container.GetConfig()
	appLogger := container.GetLogger()
	engine := handlers.NewRouter(
		appConfig,
		users,
		questions,
		learning,
		ai,
		workers,
		dailyQuestions,
		stories,
		conversations,
		oauth,
		hints,
		translations,
		snippets,
		usageStats,
		wordOfTheDay,
		apiKeys,
		appLogger,
	)

	app := &Application{container: container, router: engine}
	return app, nil
}
// Run serves HTTP on the given port until ctx is cancelled or the server
// fails.
//
// The gin engine is started on its own goroutine. A nil return means ctx was
// cancelled; a non-nil return wraps the error reported by the engine. Note
// that cancelling ctx only makes Run return: nothing in this method stops the
// goroutine that is running the listener.
func (a *Application) Run(ctx context.Context, port string) error {
	// Buffered so the serving goroutine never blocks on send if Run has
	// already returned because of cancellation.
	failed := make(chan error, 1)

	serve := func() {
		err := a.router.Run(":" + port)
		if err == nil {
			return
		}
		failed <- err
	}
	go serve()

	select {
	case err := <-failed:
		return contextutils.WrapError(err, "server failed")
	case <-ctx.Done():
		// Context cancelled, graceful shutdown
		return nil
	}
}
// Shutdown gracefully shuts down the application by delegating to the service
// container's Shutdown with the supplied context, returning its error
// unchanged. This method does not touch the router directly.
func (a *Application) Shutdown(ctx context.Context) error {
return a.container.Shutdown(ctx)
}
// main is the process entry point. All work happens in runServer so that its
// deferred cleanup (context cancellation, tracer/meter provider shutdown)
// finishes before the process exits; calling os.Exit directly from the body
// would skip every deferred function.
func main() {
	if code := runServer(); code != 0 {
		os.Exit(code)
	}
}

// runServer loads configuration, initializes observability and the service
// container, starts the HTTP application and blocks until either a
// SIGINT/SIGTERM arrives or the application reports an error.
//
// It returns the process exit code: 0 after a clean shutdown, 1 on any
// startup, runtime or shutdown failure.
func runServer() int {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Setup graceful shutdown
	shutdownCh := make(chan os.Signal, 1)
	signal.Notify(shutdownCh, syscall.SIGINT, syscall.SIGTERM)

	// Load configuration
	cfg, err := config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
		return 1
	}

	// Setup observability (tracing/metrics/logging)
	tp, mp, logger, err := observability.SetupObservability(&cfg.OpenTelemetry, "quiz-backend")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		return 1
	}
	defer func() {
		// Give the providers a bounded amount of time to flush on the way out.
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer shutdownCancel()
		if tp != nil {
			if err := tp.Shutdown(shutdownCtx); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error(), "provider": "tracer"})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(shutdownCtx); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error(), "provider": "meter"})
			}
		}
	}()

	logger.Info(ctx, "Starting quiz backend service", map[string]interface{}{
		"port":     cfg.Server.Port,
		"logLevel": cfg.Server.LogLevel,
	})

	// Initialize dependency injection container
	container := di.NewServiceContainer(cfg, logger)

	// Initialize all services
	if err := container.Initialize(ctx); err != nil {
		logger.Error(ctx, "Failed to initialize services", err, nil)
		return 1
	}

	// Ensure admin user exists
	if err := container.EnsureAdminUser(ctx); err != nil {
		logger.Error(ctx, "Failed to ensure admin user exists", err, map[string]interface{}{"admin_username": cfg.Server.AdminUsername})
		return 1
	}

	// Create application instance
	app, err := NewApplication(container)
	if err != nil {
		logger.Error(ctx, "Failed to create application", err, nil)
		return 1
	}

	// Start application in a goroutine; the channel is buffered so the
	// goroutine can report a failure even if nobody is receiving any more.
	appErr := make(chan error, 1)
	go func() {
		if err := app.Run(ctx, cfg.Server.Port); err != nil {
			appErr <- err
		}
	}()

	// Wait for shutdown signal or application error
	select {
	case <-shutdownCh:
		logger.Info(ctx, "Received shutdown signal, shutting down gracefully", nil)
	case err := <-appErr:
		logger.Error(ctx, "Application failed", err, nil)
		return 1
	}

	// Graceful shutdown
	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer shutdownCancel()

	// Shutdown application
	if err := app.Shutdown(shutdownCtx); err != nil {
		logger.Error(ctx, "Error during application shutdown", err, nil)
		return 1
	}

	logger.Info(ctx, "Shutdown completed successfully", nil)
	return 0
}
// Package main provides a utility to set up the test database with initial data.
package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"go.uber.org/zap/zapcore"
"gopkg.in/yaml.v3"
)
// TestUser represents a user in the test data files.
// Instances are decoded from YAML and turned into real accounts by
// loadAndCreateUsers.
type TestUser struct {
// Username and Email are passed to the user service when creating the account.
Username string `yaml:"username"`
Email string `yaml:"email"`
Password string `yaml:"password"` // Special field for password creation
// PreferredLanguage and CurrentLevel are used both at creation time and in
// the user's settings.
PreferredLanguage string `yaml:"preferred_language"`
CurrentLevel string `yaml:"current_level"`
// AIProvider, AIModel and AIAPIKey are copied into the user's settings; AI is
// enabled for the user when AIProvider is non-empty.
AIProvider string `yaml:"ai_provider"`
AIModel string `yaml:"ai_model"`
AIAPIKey string `yaml:"ai_api_key"`
// Roles lists role names assigned to the user after creation.
Roles []string `yaml:"roles"`
}
// TestUsers represents a collection of test users.
// It is the top-level document decoded from the users YAML fixture.
type TestUsers struct {
// Users holds the entries found under the "users" key.
Users []TestUser `yaml:"users"`
}
// TestQuestions represents a collection of test questions.
// It is the top-level document decoded from the questions YAML fixture.
type TestQuestions struct {
// Questions holds the entries found under the "questions" key, decoded
// directly into the application's question model.
Questions []models.Question `yaml:"questions"`
}
// TestResponses represents a collection of test user responses.
// The same document shape also carries question reports under a separate key.
type TestResponses struct {
// UserResponses are the answers recorded for a user.
// NOTE(review): QuestionIndex presumably indexes into the list of loaded test
// questions rather than being a database ID; confirm against the loader.
UserResponses []struct {
Username string `yaml:"username"`
QuestionIndex int `yaml:"question_index"`
UserAnswer string `yaml:"user_answer"`
IsCorrect bool `yaml:"is_correct"`
ResponseTimeMs int `yaml:"response_time_ms"`
} `yaml:"user_responses"`
// QuestionReports are user-filed reports against questions. CreatedAt is a
// pointer, so it is nil when the key is absent from the YAML.
QuestionReports []struct {
Username string `yaml:"username"`
QuestionIndex int `yaml:"question_index"`
ReportReason string `yaml:"report_reason"`
CreatedAt *string `yaml:"created_at"`
} `yaml:"question_reports"`
}
// TestAnalytics represents analytics test data.
// It groups four independent YAML lists, each keyed by username.
type TestAnalytics struct {
// PriorityScores are per-user, per-question priority values.
// LastCalculatedAt is kept as a string here; its expected format is not
// visible in this type (TODO confirm against the loader).
PriorityScores []struct {
Username string `yaml:"username"`
QuestionIndex int `yaml:"question_index"`
PriorityScore float64 `yaml:"priority_score"`
LastCalculatedAt string `yaml:"last_calculated_at"`
} `yaml:"priority_scores"`
// LearningPreferences are per-user tuning knobs for question selection.
LearningPreferences []struct {
Username string `yaml:"username"`
FocusOnWeakAreas bool `yaml:"focus_on_weak_areas"`
FreshQuestionRatio float64 `yaml:"fresh_question_ratio"`
WeakAreaBoost float64 `yaml:"weak_area_boost"`
KnownQuestionPenalty float64 `yaml:"known_question_penalty"`
ReviewIntervalDays int `yaml:"review_interval_days"`
DailyReminderEnabled bool `yaml:"daily_reminder_enabled"`
} `yaml:"learning_preferences"`
// PerformanceMetrics are aggregated attempt counts per user, topic, language
// and level.
PerformanceMetrics []struct {
Username string `yaml:"username"`
Topic string `yaml:"topic"`
Language string `yaml:"language"`
Level string `yaml:"level"`
TotalAttempts int `yaml:"total_attempts"`
CorrectAttempts int `yaml:"correct_attempts"`
AverageResponseTimeMs float64 `yaml:"average_response_time_ms"`
} `yaml:"performance_metrics"`
// UserQuestionMetadata records whether a user marked a question as known.
// MarkedAsKnownAt is a pointer, so it is nil when the key is absent.
UserQuestionMetadata []struct {
Username string `yaml:"username"`
QuestionIndex int `yaml:"question_index"`
MarkedAsKnown bool `yaml:"marked_as_known"`
MarkedAsKnownAt *string `yaml:"marked_as_known_at"`
} `yaml:"user_question_metadata"`
}
// TestDailyAssignments represents the structure for daily question assignments in test data.
type TestDailyAssignments struct {
// DailyAssignments lists, per user and date, which questions were assigned
// and which of them were completed.
// NOTE(review): whether QuestionIDs/CompletedQuestions are database IDs or
// indexes into the loaded question list is not visible here; confirm against
// the loader. Date is kept as a string; its format is likewise not shown.
DailyAssignments []struct {
Username string `yaml:"username"`
Date string `yaml:"date"`
QuestionIDs []int `yaml:"question_ids"`
CompletedQuestions []int `yaml:"completed_questions"`
} `yaml:"daily_assignments"`
}
// TestMessageData represents message data for E2E tests.
// Unlike the YAML fixture types, this type carries JSON tags.
type TestMessageData struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
Role string `json:"role"`
Content string `json:"content"`
Bookmarked bool `json:"bookmarked"`
// QuestionID is optional; a nil pointer is omitted from the JSON output.
QuestionID *int `json:"question_id,omitempty"`
// CreatedAt and UpdatedAt are carried as strings, not time.Time.
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// TestConversationData represents conversation data for E2E tests.
// It is the JSON-tagged counterpart of a conversation, with its messages
// embedded.
type TestConversationData struct {
ID string `json:"id"`
Username string `json:"username"`
Title string `json:"title"`
// Messages are the conversation's messages in JSON form.
Messages []TestMessageData `json:"messages"`
}
// TestConversations represents a collection of test conversations.
// It is the top-level document shape of the conversations YAML fixture.
type TestConversations struct {
// Conversations holds one entry per conversation, each owned by Username.
Conversations []struct {
Username string `yaml:"username"`
Title string `yaml:"title"`
// Messages are the conversation's messages as written in the YAML file.
Messages []struct {
Role string `yaml:"role"`
Content string `yaml:"content"`
// QuestionID is optional; nil when the key is absent.
QuestionID *int `yaml:"question_id"`
} `yaml:"messages"`
} `yaml:"conversations"`
}
// TestStorySectionData represents section data for E2E tests.
// It is the JSON-tagged form of one story section.
type TestStorySectionData struct {
ID int `json:"id"`
// StoryID identifies the story this section belongs to.
StoryID int `json:"story_id"`
SectionNumber int `json:"section_number"`
Content string `json:"content"`
LanguageLevel string `json:"language_level"`
WordCount int `json:"word_count"`
GeneratedBy string `json:"generated_by"`
}
// TestStoryData represents story data for E2E tests.
// It is the JSON-tagged form of a story with its sections embedded.
type TestStoryData struct {
ID int `json:"id"`
Username string `json:"username"`
Title string `json:"title"`
Status string `json:"status"`
// Sections are the story's sections in JSON form.
Sections []TestStorySectionData `json:"sections"`
}
// TestStories represents a collection of test stories.
// It is the top-level document shape of the stories YAML fixture: each story
// has sections, and each section may carry multiple-choice questions.
type TestStories struct {
Stories []struct {
Username string `yaml:"username"`
Title string `yaml:"title"`
Language string `yaml:"language"`
// The following pointer fields are optional; each is nil when its key is
// absent from the YAML.
Subject *string `yaml:"subject"`
AuthorStyle *string `yaml:"author_style"`
TimePeriod *string `yaml:"time_period"`
Genre *string `yaml:"genre"`
Tone *string `yaml:"tone"`
CharacterNames *string `yaml:"character_names"`
CustomInstructions *string `yaml:"custom_instructions"`
SectionLengthOverride *string `yaml:"section_length_override"`
Status string `yaml:"status"`
IsCurrent bool `yaml:"is_current"`
// Sections are the story's sections as written in the YAML file.
Sections []struct {
SectionNumber int `yaml:"section_number"`
Content string `yaml:"content"`
LanguageLevel string `yaml:"language_level"`
WordCount int `yaml:"word_count"`
GeneratedBy string `yaml:"generated_by"`
// Questions belong to this section.
// NOTE(review): CorrectAnswerIndex presumably indexes into Options; confirm
// whether it is zero-based.
Questions []struct {
QuestionText string `yaml:"question_text"`
Options []string `yaml:"options"`
CorrectAnswerIndex int `yaml:"correct_answer_index"`
Explanation *string `yaml:"explanation"`
} `yaml:"questions"`
} `yaml:"sections"`
} `yaml:"stories"`
}
// TestSnippetData represents snippet data for E2E tests.
// It is the JSON-tagged form of a snippet: a piece of text with its
// translation and the language pair.
type TestSnippetData struct {
ID int `json:"id"`
Username string `json:"username"`
OriginalText string `json:"original_text"`
TranslatedText string `json:"translated_text"`
SourceLanguage string `json:"source_language"`
TargetLanguage string `json:"target_language"`
}
// TestSnippets represents a collection of test snippets.
// It is the top-level document shape of the snippets YAML fixture.
type TestSnippets struct {
Snippets []struct {
Username string `yaml:"username"`
OriginalText string `yaml:"original_text"`
TranslatedText string `yaml:"translated_text"`
SourceLanguage string `yaml:"source_language"`
TargetLanguage string `yaml:"target_language"`
// Context is optional; nil when the key is absent.
Context *string `yaml:"context"`
DifficultyLevel string `yaml:"difficulty_level"`
} `yaml:"snippets"`
}
// TestFeedbackData represents feedback data for E2E tests.
// It is the JSON-tagged form of one feedback report.
type TestFeedbackData struct {
ID int `json:"id"`
Username string `json:"username"`
FeedbackText string `json:"feedback_text"`
FeedbackType string `json:"feedback_type"`
Status string `json:"status"`
// ContextData is free-form; no schema is imposed by this type.
ContextData map[string]interface{} `json:"context_data"`
}
// TestFeedback represents a collection of test feedback.
// It is the top-level document shape of the feedback YAML fixture.
type TestFeedback struct {
// FeedbackReports holds the entries found under the "feedback_reports" key.
FeedbackReports []struct {
Username string `yaml:"username"`
FeedbackText string `yaml:"feedback_text"`
FeedbackType string `yaml:"feedback_type"`
Status string `yaml:"status"`
// ContextData is free-form; no schema is imposed by this type.
ContextData map[string]interface{} `yaml:"context_data"`
} `yaml:"feedback_reports"`
}
// resetTestDatabase drops and recreates the database named testDB.
//
// It derives an administrative connection string from databaseURL by swapping
// the testDB path segment for "postgres", connects with the "postgres" driver,
// terminates other sessions attached to testDB, then issues DROP DATABASE IF
// EXISTS followed by CREATE DATABASE.
//
// The database name is never spliced into SQL as raw text: the session filter
// uses a bind parameter, and DROP/CREATE (which cannot take parameters for
// identifiers) use a double-quoted identifier with embedded quotes doubled.
// This keeps names containing hyphens, upper-case letters or quotes working
// and rules out statement injection through testDB.
//
// A failure to terminate sessions is logged as a warning only; failures to
// open the connection, drop, or create are returned as wrapped errors.
func resetTestDatabase(databaseURL, testDB string, logger *observability.Logger) error {
	ctx := context.Background()

	// Create admin connection string by replacing the database name with 'postgres'
	// This connects to the admin database to drop/create the test database
	adminConnStr := strings.Replace(databaseURL, "/"+testDB+"?", "/postgres?", 1)
	if !strings.Contains(adminConnStr, "/postgres?") {
		// Handle case where there's no query string
		adminConnStr = strings.Replace(databaseURL, "/"+testDB, "/postgres", 1)
	}

	// NOTE(review): this logs the full connection string, including any
	// credentials embedded in databaseURL.
	logger.Info(ctx, "Connecting to admin database", map[string]interface{}{"connection_string": adminConnStr})
	adminDB, err := sql.Open("postgres", adminConnStr)
	if err != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseConnection, "failed to connect to postgres database for drop/create: %v", err)
	}
	defer func() {
		if err := adminDB.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close adminDB", map[string]interface{}{"error": err.Error()})
		}
	}()

	// Quote the identifier once for the statements that cannot be parameterized.
	quotedDB := `"` + strings.ReplaceAll(testDB, `"`, `""`) + `"`

	logger.Info(ctx, "Terminating connections to test DB", map[string]interface{}{"database": testDB})
	_, err = adminDB.ExecContext(ctx, `
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = $1 AND pid <> pg_backend_pid()
`, testDB)
	if err != nil {
		// Best effort: DROP ... WITH (FORCE) below also disconnects sessions.
		logger.Warn(ctx, "Warning: failed to terminate connections", map[string]interface{}{"error": err.Error()})
	}

	logger.Info(ctx, "Dropping test database", map[string]interface{}{"database": testDB})
	_, err = adminDB.ExecContext(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE);", quotedDB))
	if err != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to drop test database: %v", err)
	}
	logger.Info(ctx, "Successfully dropped test database", map[string]interface{}{"database": testDB})

	logger.Info(ctx, "Creating test database", map[string]interface{}{"database": testDB})
	_, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s;", quotedDB))
	if err != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to create test database: %v", err)
	}
	logger.Info(ctx, "Successfully created test database", map[string]interface{}{"database": testDB})

	logger.Info(ctx, "Test database reset complete")
	return nil
}
// main is the entry point of the test-database setup utility. All work
// happens in runSetupTestDB so that its deferred cleanup (telemetry provider
// shutdown, closing the database handle) finishes before the process exits;
// calling os.Exit directly from the body would skip every deferred function.
func main() {
	if code := runSetupTestDB(); code != 0 {
		os.Exit(code)
	}
}

// runSetupTestDB recreates the test database, applies the schema, ensures the
// admin user exists, loads the YAML fixtures and writes the JSON artifacts
// consumed by the E2E tests.
//
// It returns the process exit code: 0 on success, 1 on the first failure.
func runSetupTestDB() int {
	ctx := context.Background()

	// CLI flags
	verbose := flag.Bool("verbose", false, "enable verbose logging")
	flag.Parse()

	// Load configuration first
	cfg, err := config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load config: %v\n", err)
		return 1
	}

	// Setup observability (tracing/metrics). Suppress logger creation here to avoid startup noise.
	originalLogging := cfg.OpenTelemetry.EnableLogging
	cfg.OpenTelemetry.EnableLogging = false
	tp, mp, _, err := observability.SetupObservability(&cfg.OpenTelemetry, "setup-test-db")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		return 1
	}

	// Create logger with level based on --verbose flag
	logLevel := zapcore.WarnLevel
	if *verbose {
		logLevel = zapcore.InfoLevel
	}
	// Restore config flag for logger construction (to allow OTLP exporter if enabled)
	cfg.OpenTelemetry.EnableLogging = originalLogging
	logger := observability.NewLoggerWithLevel(&cfg.OpenTelemetry, logLevel)

	defer func() {
		if tp != nil {
			if err := tp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error()})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error()})
			}
		}
	}()

	// Get DB connection info from env or use defaults
	dbUser := "quiz_user"
	dbPassword := "quiz_password"
	dbHost := "localhost"
	dbPort := "5433"
	testDB := "quiz_test_db"

	// Allow override from DATABASE_URL
	databaseURL := os.Getenv("DATABASE_URL")
	if databaseURL == "" {
		databaseURL = fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", dbUser, dbPassword, dbHost, dbPort, testDB)
	}

	// Debug: Print the DATABASE_URL we're using
	logger.Info(ctx, "DATABASE_URL from environment", map[string]interface{}{"database_url": os.Getenv("DATABASE_URL")})
	logger.Info(ctx, "Using database URL", map[string]interface{}{"database_url": databaseURL})

	// --- Drop and recreate the test database ---
	if err := resetTestDatabase(databaseURL, testDB, logger); err != nil {
		logger.Error(ctx, "Failed to reset test database", err)
		return 1
	}

	// Now connect to the new test database
	logger.Info(ctx, "Connecting to database", map[string]interface{}{"database_url": databaseURL})

	// Initialize database manager with logger
	dbManager := database.NewManager(logger)
	db, err := dbManager.InitDB(databaseURL)
	if err != nil {
		logger.Error(ctx, "Failed to initialize database", err)
		return 1
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database", map[string]interface{}{"error": err.Error()})
		}
	}()

	// Get the root directory (backend is the working directory)
	rootDir, err := os.Getwd()
	if err != nil {
		logger.Error(ctx, "Failed to get working directory", err)
		return 1
	}

	// Apply schema from schema.sql
	schemaPath := filepath.Join(rootDir, "..", "schema.sql")
	if err := applySchema(db, schemaPath, rootDir, logger); err != nil {
		logger.Error(ctx, "Failed to apply schema", err)
		return 1
	}

	// Initialize services
	userService := services.NewUserServiceWithLogger(db, cfg, logger)
	learningService := services.NewLearningServiceWithLogger(db, cfg, logger)
	// Create question service
	questionService := services.NewQuestionServiceWithLogger(db, learningService, cfg, logger)

	// Ensure admin user exists
	if err := userService.EnsureAdminUserExists(ctx, "admin", "password"); err != nil {
		logger.Error(ctx, "Failed to ensure admin user exists", err)
		return 1
	}

	// Load and insert test data
	users, err := setupTestData(ctx, rootDir, userService, questionService, learningService, db, logger)
	if err != nil {
		logger.Error(ctx, "Failed to setup test data", err)
		return 1
	}

	// Output user data to JSON file for E2E tests
	if err := outputUserDataForTests(users, rootDir, logger); err != nil {
		logger.Error(ctx, "Failed to output user data for tests", err)
		return 1
	}

	// Output roles data to JSON file for E2E tests
	if err := outputRolesDataForTests(db, rootDir, logger); err != nil {
		logger.Error(ctx, "Failed to output roles data for tests", err)
		return 1
	}

	logger.Info(ctx, "Test database created successfully")
	return 0
}
// applySchema reads the SQL file at schemaPath and executes its entire
// contents against db in a single Exec call.
//
// The second string parameter is unused. Errors from reading the file or from
// executing the SQL are returned wrapped with ErrDatabaseQuery.
func applySchema(db *sql.DB, schemaPath, _ string, logger *observability.Logger) error {
	ctx := context.Background()

	// The database is expected to be empty at this point (freshly recreated).
	logger.Info(ctx, "Applying schema")

	contents, readErr := os.ReadFile(schemaPath)
	if readErr != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to read schema file: %v", readErr)
	}

	_, execErr := db.Exec(string(contents))
	if execErr != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to execute schema: %v", execErr)
	}

	// The priority system tables ship inside the main schema file, so there is
	// no follow-up migration to run here.
	logger.Info(ctx, "Priority system tables already included in main schema")
	return nil
}
// setupTestData loads every YAML fixture under <rootDir>/data in a fixed
// order, inserts the data, and writes the JSON artifacts for the E2E tests.
//
// Steps run strictly in sequence because later steps consume the users and
// questions created by earlier ones. The first failing step aborts the run and
// its error is returned wrapped with ErrDatabaseQuery. On success the created
// users, keyed by username, are returned.
func setupTestData(ctx context.Context, rootDir string, userService *services.UserService, questionService *services.QuestionService, learningService *services.LearningService, db *sql.DB, logger *observability.Logger) (map[string]*models.User, error) {
	dataDir := filepath.Join(rootDir, "data")

	// fixture resolves a file name inside the data directory.
	fixture := func(name string) string {
		return filepath.Join(dataDir, name)
	}
	// wrap tags a step failure with the shared database-query sentinel.
	wrap := func(format string, cause error) error {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, format, cause)
	}

	// Step 1: users.
	users, err := loadAndCreateUsers(ctx, fixture("test_users.yaml"), userService, logger)
	if err != nil {
		return nil, wrap("failed to setup users: %v", err)
	}

	// Step 2: questions, assigned to the users from step 1.
	questions, err := loadAndCreateQuestions(ctx, fixture("test_questions.yaml"), questionService, users, logger)
	if err != nil {
		return nil, wrap("failed to setup questions: %v", err)
	}

	// Step 3: user responses.
	err = loadAndCreateResponses(ctx, fixture("test_responses.yaml"), users, questions, learningService, logger)
	if err != nil {
		return nil, wrap("failed to setup responses: %v", err)
	}

	// Step 4: question reports (read from the same file as the responses).
	err = loadAndCreateQuestionReports(ctx, fixture("test_responses.yaml"), users, questions, db, logger)
	if err != nil {
		return nil, wrap("failed to setup question reports: %v", err)
	}

	// Step 5: analytics.
	err = loadAndCreateAnalytics(ctx, fixture("test_analytics.yaml"), users, questions, learningService, db, logger)
	if err != nil {
		return nil, wrap("failed to setup analytics: %v", err)
	}

	// Step 6: daily assignments.
	err = loadAndCreateDailyAssignments(ctx, fixture("test_daily_assignments.yaml"), users, questions, db, logger)
	if err != nil {
		return nil, wrap("failed to setup daily assignments: %v", err)
	}

	// Step 7: stories, followed by their E2E artifact.
	stories, err := loadAndCreateStories(ctx, fixture("test_stories.yaml"), users, db, logger)
	if err != nil {
		return nil, wrap("failed to setup stories: %v", err)
	}
	err = outputStoryDataForTests(stories, rootDir, logger)
	if err != nil {
		return nil, wrap("failed to output story data: %v", err)
	}

	// Step 8: snippets, followed by their E2E artifact.
	snippets, err := loadAndCreateSnippets(ctx, fixture("test_snippets.yaml"), users, db, logger)
	if err != nil {
		return nil, wrap("failed to setup snippets: %v", err)
	}
	err = outputSnippetDataForTests(snippets, rootDir, logger)
	if err != nil {
		return nil, wrap("failed to output snippet data: %v", err)
	}

	// Step 9: conversations, followed by their E2E artifact.
	conversations, err := loadAndCreateConversations(ctx, fixture("test_conversations.yaml"), users, db, logger)
	if err != nil {
		return nil, wrap("failed to setup conversations: %v", err)
	}
	err = outputConversationDataForTests(conversations, rootDir, logger)
	if err != nil {
		return nil, wrap("failed to output conversation data: %v", err)
	}

	// Step 10: feedback reports, followed by their E2E artifact.
	feedback, err := loadAndCreateFeedback(ctx, fixture("test_feedback.yaml"), users, db, logger)
	if err != nil {
		return nil, wrap("failed to setup feedback: %v", err)
	}
	err = outputFeedbackDataForTests(feedback, rootDir, logger)
	if err != nil {
		return nil, wrap("failed to output feedback data: %v", err)
	}

	// Step 11: API keys for the test users.
	err = createAndOutputAPIKeysForTests(ctx, users, db, rootDir, logger)
	if err != nil {
		return nil, wrap("failed to setup api keys: %v", err)
	}

	return users, nil
}
// TestAPIKeyData represents API key data for E2E tests (non-sensitive).
// Only identifying metadata is kept: the key's ID, name, prefix and permission
// level. The type has no field for the key value itself.
type TestAPIKeyData struct {
ID int `json:"id"`
// Username is the owner the key was created for.
Username string `json:"username"`
KeyName string `json:"key_name"`
KeyPrefix string `json:"key_prefix"`
PermissionLevel string `json:"permission_level"`
CreatedAt time.Time `json:"created_at"`
}
// createAndOutputAPIKeysForTests issues API keys for the given users and
// writes their non-sensitive metadata to frontend/tests/test-api-keys.json
// (relative to the parent of rootDir).
//
// Every user receives one read-only key, stored under "<username>_ro". The
// users "apitestuser" and "apitestadmin" additionally receive a full-access
// key, stored under "<username>_full". The second value returned by the key
// service is discarded and never written to disk. The first key that cannot be
// created, or any filesystem/JSON failure, aborts with a wrapped error.
func createAndOutputAPIKeysForTests(ctx context.Context, users map[string]*models.User, db *sql.DB, rootDir string, logger *observability.Logger) error {
	keyService := services.NewAuthAPIKeyService(db, logger)

	// issue creates a single key and reduces it to the fields the tests need.
	issue := func(username string, userID int, keyName, perm string) (*TestAPIKeyData, error) {
		created, _, err := keyService.CreateAPIKey(ctx, userID, keyName, perm)
		if err != nil {
			return nil, err
		}
		summary := TestAPIKeyData{
			ID:              created.ID,
			Username:        username,
			KeyName:         created.KeyName,
			KeyPrefix:       created.KeyPrefix,
			PermissionLevel: created.PermissionLevel,
			CreatedAt:       created.CreatedAt,
		}
		return &summary, nil
	}

	apiKeys := make(map[string]TestAPIKeyData)
	for username, user := range users {
		// The read-only key is common to every user.
		readonly, err := issue(username, user.ID, "test_key_readonly", string(models.PermissionLevelReadonly))
		if err != nil {
			return contextutils.WrapErrorf(err, "failed creating readonly api key for %s", username)
		}
		apiKeys[fmt.Sprintf("%s_ro", username)] = *readonly

		// Only the two dedicated API test accounts also get a full key.
		if username != "apitestuser" && username != "apitestadmin" {
			continue
		}
		full, err := issue(username, user.ID, "test_key_full", string(models.PermissionLevelFull))
		if err != nil {
			return contextutils.WrapErrorf(err, "failed creating full api key for %s", username)
		}
		apiKeys[fmt.Sprintf("%s_full", username)] = *full
	}

	// Write to JSON file in the frontend/tests directory
	outputPath := filepath.Join(rootDir, "..", "frontend", "tests", "test-api-keys.json")
	outputDir := filepath.Dir(outputPath)
	if mkErr := os.MkdirAll(outputDir, 0o755); mkErr != nil {
		return contextutils.WrapErrorf(mkErr, "failed to create output directory: %s", outputDir)
	}

	encoded, err := json.MarshalIndent(apiKeys, "", " ")
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to marshal api keys data to JSON")
	}

	if writeErr := os.WriteFile(outputPath, encoded, 0o644); writeErr != nil {
		return contextutils.WrapErrorf(writeErr, "failed to write api keys data to file: %s", outputPath)
	}

	logFields := map[string]interface{}{
		"file_path":  outputPath,
		"keys_count": len(apiKeys),
	}
	logger.Info(context.Background(), "Output API keys data for E2E tests", logFields)
	return nil
}
// loadAndCreateUsers reads the YAML user fixture at filePath and creates each
// listed user through userService.
//
// For every entry it creates the account (timezone fixed to "UTC"), sets the
// password in a separate call, stores the language/level/AI settings, and then
// assigns the listed roles. A failure to create the user, set the password or
// update the settings aborts the whole load with a wrapped error; a failed
// role assignment is only logged as a warning and processing continues.
//
// It returns the created users keyed by username.
func loadAndCreateUsers(ctx context.Context, filePath string, userService *services.UserService, logger *observability.Logger) (result0 map[string]*models.User, err error) {
	raw, err := os.ReadFile(filePath)
	if err != nil {
		return nil, err
	}

	var fixture TestUsers
	if err = yaml.Unmarshal(raw, &fixture); err != nil {
		return nil, err
	}

	created := make(map[string]*models.User)
	for _, spec := range fixture.Users {
		// Create user with email and timezone
		user, createErr := userService.CreateUserWithEmailAndTimezone(
			ctx,
			spec.Username,
			spec.Email,
			"UTC", // Default timezone for test users
			spec.PreferredLanguage,
			spec.CurrentLevel,
		)
		if createErr != nil {
			return nil, contextutils.WrapErrorf(createErr, "failed to create user %s", spec.Username)
		}

		// The creation call above does not set a password, so do it separately.
		if pwErr := userService.UpdateUserPassword(ctx, user.ID, spec.Password); pwErr != nil {
			return nil, contextutils.WrapErrorf(pwErr, "failed to set password for user %s", spec.Username)
		}

		// AI is switched on only for users that name a provider.
		aiEnabled := spec.AIProvider != ""
		settings := models.UserSettings{
			Language:   spec.PreferredLanguage,
			Level:      spec.CurrentLevel,
			AIProvider: spec.AIProvider,
			AIModel:    spec.AIModel,
			AIAPIKey:   spec.AIAPIKey,
			AIEnabled:  aiEnabled,
		}
		if setErr := userService.UpdateUserSettings(ctx, user.ID, &settings); setErr != nil {
			return nil, contextutils.WrapErrorf(setErr, "failed to update settings for user %s", spec.Username)
		}

		// Assign roles from YAML configuration; failures here are non-fatal.
		for _, role := range spec.Roles {
			if roleErr := userService.AssignRoleByName(ctx, user.ID, role); roleErr != nil {
				logger.Warn(ctx, "Failed to assign role to user", map[string]interface{}{
					"username": spec.Username,
					"role":     role,
					"error":    roleErr.Error(),
				})
				continue
			}
			logger.Info(ctx, "Assigned role to user", map[string]interface{}{
				"username": spec.Username,
				"role":     role,
				"user_id":  user.ID,
			})
		}

		created[spec.Username] = user
	}

	return created, nil
}
// loadAndCreateQuestions reads the YAML file at filePath, decodes it into
// TestQuestions, saves every question through questionService, and assigns it
// to users.
//
// A question that lists usernames is assigned to exactly those users; an
// unknown username aborts the load. A question with no usernames is assigned
// to a single user chosen round-robin (by question position) from users, and
// the load fails if users is empty.
//
// It returns pointers to the saved questions in file order. Each pointer
// refers to a distinct element of the decoded slice, so the result does not
// depend on the range-variable semantics of the Go version in use.
func loadAndCreateQuestions(ctx context.Context, filePath string, questionService *services.QuestionService, users map[string]*models.User, _ *observability.Logger) (result0 []*models.Question, err error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return nil, err
	}
	var testQuestions TestQuestions
	if err := yaml.Unmarshal(data, &testQuestions); err != nil {
		return nil, err
	}

	// Collect the candidate IDs once. Map iteration order is unspecified, so
	// re-reading the map for every question could hand the round-robin a
	// differently ordered list each time; a single snapshot keeps the rotation
	// consistent within one run (the starting order is still arbitrary).
	roundRobinIDs := make([]int, 0, len(users))
	for _, user := range users {
		roundRobinIDs = append(roundRobinIDs, user.ID)
	}

	var questions []*models.Question
	for i := range testQuestions.Questions {
		// Take the address of the slice element rather than of a range
		// variable: before Go 1.22 the range variable is shared by all
		// iterations, so every stored pointer would alias the last question.
		question := &testQuestions.Questions[i]

		// Set the created time since it's not in YAML
		question.CreatedAt = time.Now()

		var assignedUserIDs []int
		if len(question.Users) == 0 {
			// Fallback to round-robin if no users specified
			if len(roundRobinIDs) == 0 {
				return nil, contextutils.ErrorWithContextf("no users available to assign questions to")
			}
			assignedUserIDs = []int{roundRobinIDs[i%len(roundRobinIDs)]}
		} else {
			for _, username := range question.Users {
				user, exists := users[username]
				if !exists {
					return nil, contextutils.ErrorWithContextf("user not found: %s", username)
				}
				assignedUserIDs = append(assignedUserIDs, user.ID)
			}
		}

		if err := questionService.SaveQuestion(ctx, question); err != nil {
			return nil, contextutils.WrapErrorf(err, "failed to save question %d", i)
		}
		for _, userID := range assignedUserIDs {
			if err := questionService.AssignQuestionToUser(ctx, question.ID, userID); err != nil {
				return nil, contextutils.WrapErrorf(err, "failed to assign question %d to user %d", question.ID, userID)
			}
		}
		questions = append(questions, question)
	}
	return questions, nil
}
// loadAndCreateResponses reads the YAML file at filePath, decodes it into
// TestResponses, and records every entry of UserResponses through
// learningService.RecordAnswerWithPriority.
//
// Each entry names a user (looked up in users) and a question by zero-based
// position in questions. An unknown username or an index outside
// [0, len(questions)) aborts the load with an error instead of panicking.
// The caller's ctx is forwarded to the service so cancellation is honoured.
func loadAndCreateResponses(ctx context.Context, filePath string, users map[string]*models.User, questions []*models.Question, learningService *services.LearningService, _ *observability.Logger) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return err
	}
	var testResponses TestResponses
	if err := yaml.Unmarshal(data, &testResponses); err != nil {
		return err
	}
	for i, responseData := range testResponses.UserResponses {
		user, exists := users[responseData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found: %s", responseData.Username)
		}
		// Reject negative indexes as well as too-large ones; either would
		// panic on the slice access below.
		if responseData.QuestionIndex < 0 || responseData.QuestionIndex >= len(questions) {
			return contextutils.ErrorWithContextf("question index out of range: %d", responseData.QuestionIndex)
		}
		question := questions[responseData.QuestionIndex]
		// Use RecordAnswerWithPriority to ensure priority scores are calculated
		if err := learningService.RecordAnswerWithPriority(
			ctx,
			user.ID,
			question.ID,
			0, // Use index 0 for test data
			responseData.IsCorrect,
			responseData.ResponseTimeMs,
		); err != nil {
			return contextutils.WrapErrorf(err, "failed to record response %d", i)
		}
	}
	return nil
}
// loadAndCreateQuestionReports reads the YAML file at filePath, decodes it
// into TestResponses, and inserts every entry of QuestionReports into the
// question_reports table.
//
// Each entry names a user (looked up in users) and a question by zero-based
// position in questions. An unknown username, an index outside
// [0, len(questions)), or an unparsable RFC 3339 timestamp aborts the load.
// When an entry has no timestamp the current time is used. Rows that conflict
// on (question_id, reported_by_user_id) are left untouched by the INSERT.
func loadAndCreateQuestionReports(ctx context.Context, filePath string, users map[string]*models.User, questions []*models.Question, db *sql.DB, _ *observability.Logger) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return contextutils.WrapError(err, "failed to read responses file")
	}
	var testResponses TestResponses
	if err := yaml.Unmarshal(data, &testResponses); err != nil {
		return contextutils.WrapError(err, "failed to parse responses data")
	}
	// Load question reports
	for i, reportData := range testResponses.QuestionReports {
		user, exists := users[reportData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for question report: %s", reportData.Username)
		}
		// Reject negative indexes as well as too-large ones; either would
		// panic on the slice access below.
		if reportData.QuestionIndex < 0 || reportData.QuestionIndex >= len(questions) {
			return contextutils.ErrorWithContextf("question index out of range for question report: %d", reportData.QuestionIndex)
		}
		question := questions[reportData.QuestionIndex]

		// Parse the timestamp if provided, otherwise use current time
		createdAt := time.Now()
		if reportData.CreatedAt != nil {
			parsed, err := time.Parse(time.RFC3339, *reportData.CreatedAt)
			if err != nil {
				// Wrap rather than replace so the parser's reason is kept.
				return contextutils.WrapErrorf(err, "invalid timestamp format for question report: %s", *reportData.CreatedAt)
			}
			createdAt = parsed
		}

		// Insert question report directly into database
		if _, err := db.ExecContext(ctx, `
INSERT INTO question_reports (question_id, reported_by_user_id, report_reason, created_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (question_id, reported_by_user_id) DO NOTHING
`, question.ID, user.ID, reportData.ReportReason, createdAt); err != nil {
			return contextutils.WrapErrorf(err, "failed to insert question report %d", i)
		}
	}
	return nil
}
// loadAndCreateAnalytics reads the optional YAML file at filePath, decodes it
// into TestAnalytics, and seeds four kinds of analytics data:
//
//   - priority scores, upserted into question_priority_scores;
//   - learning preferences, applied through learningService;
//   - performance metrics, upserted into performance_metrics;
//   - "marked as known" question metadata, upserted into
//     user_question_metadata (entries not marked as known are validated but
//     not written).
//
// A missing file is logged as a warning and treated as success. Any other
// read error, a parse error, an unknown username, a question index outside
// [0, len(questions)), an invalid RFC 3339 timestamp, or a failed write aborts
// the load and is returned.
func loadAndCreateAnalytics(ctx context.Context, filePath string, users map[string]*models.User, questions []*models.Question, learningService *services.LearningService, db *sql.DB, logger *observability.Logger) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			// Analytics file is optional, so just return if it doesn't exist
			logger.Warn(ctx, "Analytics file not found", map[string]interface{}{"file_path": filePath})
			return nil
		}
		// The file exists but could not be read; skipping it silently would
		// hide the problem behind a misleading "not found" warning.
		return contextutils.WrapErrorf(err, "failed to read analytics file: %s", filePath)
	}
	var testAnalytics TestAnalytics
	if err := yaml.Unmarshal(data, &testAnalytics); err != nil {
		return contextutils.WrapError(err, "failed to parse analytics data")
	}

	// validIndex guards slice accesses into questions; negative indexes are
	// rejected too, since they would panic.
	validIndex := func(idx int) bool { return idx >= 0 && idx < len(questions) }

	// Load priority scores
	for _, priorityData := range testAnalytics.PriorityScores {
		user, exists := users[priorityData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for priority score: %s", priorityData.Username)
		}
		if !validIndex(priorityData.QuestionIndex) {
			return contextutils.ErrorWithContextf("question index out of range for priority score: %d", priorityData.QuestionIndex)
		}
		question := questions[priorityData.QuestionIndex]

		lastCalculatedAt, err := time.Parse(time.RFC3339, priorityData.LastCalculatedAt)
		if err != nil {
			// Wrap rather than replace so the parser's reason is kept.
			return contextutils.WrapErrorf(err, "invalid timestamp format for priority score: %s", priorityData.LastCalculatedAt)
		}

		// Insert priority score directly into database
		if _, err := db.ExecContext(ctx, `
INSERT INTO question_priority_scores (user_id, question_id, priority_score, last_calculated_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (user_id, question_id) DO UPDATE SET
priority_score = EXCLUDED.priority_score,
last_calculated_at = EXCLUDED.last_calculated_at,
updated_at = NOW()
`, user.ID, question.ID, priorityData.PriorityScore, lastCalculatedAt); err != nil {
			return contextutils.WrapError(err, "failed to insert priority score")
		}
	}

	// Load learning preferences
	for _, prefData := range testAnalytics.LearningPreferences {
		user, exists := users[prefData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for learning preferences: %s", prefData.Username)
		}

		// The YAML entries carry no daily goal, so the service default is
		// always used; fetching it from the service avoids duplicating the
		// value here.
		defaultPrefs := learningService.GetDefaultLearningPreferences()
		prefs := &models.UserLearningPreferences{
			UserID:               user.ID,
			FocusOnWeakAreas:     prefData.FocusOnWeakAreas,
			FreshQuestionRatio:   prefData.FreshQuestionRatio,
			WeakAreaBoost:        prefData.WeakAreaBoost,
			KnownQuestionPenalty: prefData.KnownQuestionPenalty,
			ReviewIntervalDays:   prefData.ReviewIntervalDays,
			DailyReminderEnabled: prefData.DailyReminderEnabled,
			DailyGoal:            defaultPrefs.DailyGoal,
		}
		if _, err := learningService.UpdateUserLearningPreferences(ctx, user.ID, prefs); err != nil {
			return contextutils.WrapErrorf(err, "failed to update learning preferences for user %s", prefData.Username)
		}
	}

	// Load performance metrics
	for _, metricData := range testAnalytics.PerformanceMetrics {
		user, exists := users[metricData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for performance metrics: %s", metricData.Username)
		}

		// Insert performance metric directly into database
		if _, err := db.ExecContext(ctx, `
INSERT INTO performance_metrics (user_id, topic, language, level, total_attempts, correct_attempts, average_response_time_ms, last_updated)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (user_id, topic, language, level) DO UPDATE SET
total_attempts = EXCLUDED.total_attempts,
correct_attempts = EXCLUDED.correct_attempts,
average_response_time_ms = EXCLUDED.average_response_time_ms,
last_updated = NOW()
`, user.ID, metricData.Topic, metricData.Language, metricData.Level,
			metricData.TotalAttempts, metricData.CorrectAttempts, metricData.AverageResponseTimeMs); err != nil {
			return contextutils.WrapError(err, "failed to insert performance metric")
		}
	}

	// Load user question metadata (marked as known)
	for _, metadata := range testAnalytics.UserQuestionMetadata {
		user, exists := users[metadata.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for question metadata: %s", metadata.Username)
		}
		if !validIndex(metadata.QuestionIndex) {
			return contextutils.ErrorWithContextf("question index out of range for metadata: %d", metadata.QuestionIndex)
		}
		question := questions[metadata.QuestionIndex]

		// Only entries flagged as known produce a row.
		if !metadata.MarkedAsKnown {
			continue
		}

		// Use the supplied timestamp when present, otherwise the current time.
		markedAt := time.Now()
		if metadata.MarkedAsKnownAt != nil {
			parsed, err := time.Parse(time.RFC3339, *metadata.MarkedAsKnownAt)
			if err != nil {
				return contextutils.WrapErrorf(err, "invalid timestamp format for marked as known: %s", *metadata.MarkedAsKnownAt)
			}
			markedAt = parsed
		}

		// Insert into user_question_metadata table
		if _, err := db.ExecContext(ctx, `
INSERT INTO user_question_metadata (user_id, question_id, marked_as_known, marked_as_known_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (user_id, question_id) DO UPDATE SET
marked_as_known = EXCLUDED.marked_as_known,
marked_as_known_at = EXCLUDED.marked_as_known_at,
updated_at = NOW()
`, user.ID, question.ID, metadata.MarkedAsKnown, markedAt); err != nil {
			return contextutils.WrapError(err, "failed to insert question metadata")
		}
	}
	return nil
}
// loadAndCreateDailyAssignments reads the optional YAML file at filePath,
// decodes it into TestDailyAssignments, and writes one
// daily_question_assignments row per (user, question, date) listed.
//
// Question IDs in the file are 1-based positions into questions. Entries with
// an unknown username or an unparsable "2006-01-02" date, and question IDs
// outside [1, len(questions)], are logged as warnings and skipped. For each
// remaining question any existing row for the same user, question and date is
// deleted before the new row is inserted. Questions listed under
// CompletedQuestions are stored as completed with the current time.
//
// A missing file is logged and treated as success. Any other read error, a
// parse error, or a failed DELETE/INSERT aborts the load and is returned.
func loadAndCreateDailyAssignments(ctx context.Context, filePath string, users map[string]*models.User, questions []*models.Question, db *sql.DB, logger *observability.Logger) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			// File doesn't exist, skip daily assignments
			logger.Info(ctx, "Daily assignments file not found, skipping", map[string]interface{}{
				"file_path": filePath,
			})
			return nil
		}
		// The file exists but could not be read; report it instead of
		// pretending it was absent.
		return contextutils.WrapErrorf(err, "failed to read daily assignments file: %s", filePath)
	}
	var testDailyAssignments TestDailyAssignments
	if err := yaml.Unmarshal(data, &testDailyAssignments); err != nil {
		return contextutils.WrapError(err, "failed to parse daily assignments data")
	}

	// Ensure we don't violate unique constraint by removing any existing assignment for the same
	// (user_id, question_id, assignment_date) tuple before inserting. This avoids relying on
	// ON CONFLICT which requires the constraint to be present in some test DB states.
	const deleteQuery = `DELETE FROM daily_question_assignments WHERE user_id = $1 AND question_id = $2 AND assignment_date = $3`
	const insertQuery = `
INSERT INTO daily_question_assignments (user_id, question_id, assignment_date, is_completed, completed_at)
VALUES ($1, $2, $3, $4, $5)
`

	for _, assignmentData := range testDailyAssignments.DailyAssignments {
		user, exists := users[assignmentData.Username]
		if !exists {
			logger.Warn(ctx, "User not found for daily assignment", map[string]interface{}{
				"username": assignmentData.Username,
			})
			continue
		}

		// Parse the date
		date, err := time.Parse("2006-01-02", assignmentData.Date)
		if err != nil {
			logger.Warn(ctx, "Invalid date format for daily assignment", map[string]interface{}{
				"username": assignmentData.Username,
				"date":     assignmentData.Date,
			})
			continue
		}

		// Create a map of completed questions for quick lookup
		completedQuestions := make(map[int]bool, len(assignmentData.CompletedQuestions))
		for _, qID := range assignmentData.CompletedQuestions {
			completedQuestions[qID] = true
		}

		// created counts rows actually written, so the summary log below is
		// not inflated by question IDs that were skipped.
		created := 0
		for _, questionID := range assignmentData.QuestionIDs {
			// Check if question exists
			if questionID <= 0 || questionID > len(questions) {
				logger.Warn(ctx, "Question ID out of range for daily assignment", map[string]interface{}{
					"username":    assignmentData.Username,
					"date":        assignmentData.Date,
					"question_id": questionID,
				})
				continue
			}
			question := questions[questionID-1] // Convert to 0-based index

			if _, err := db.ExecContext(ctx, deleteQuery, user.ID, question.ID, date); err != nil {
				logger.Error(ctx, "Failed to delete existing daily assignment", err, map[string]interface{}{
					"username":    assignmentData.Username,
					"date":        assignmentData.Date,
					"question_id": questionID,
				})
				return contextutils.WrapErrorf(err, "failed to delete existing daily assignment for user %s, question %d", assignmentData.Username, questionID)
			}

			// completedAt stays nil (SQL NULL) unless the question is
			// listed as completed.
			isCompleted := completedQuestions[questionID]
			var completedAt *time.Time
			if isCompleted {
				now := time.Now()
				completedAt = &now
			}

			// Insert the assignment directly into the database
			if _, err := db.ExecContext(ctx, insertQuery, user.ID, question.ID, date, isCompleted, completedAt); err != nil {
				logger.Error(ctx, "Failed to create daily assignment", err, map[string]interface{}{
					"username":    assignmentData.Username,
					"date":        assignmentData.Date,
					"question_id": questionID,
				})
				return contextutils.WrapErrorf(err, "failed to create daily assignment for user %s, question %d", assignmentData.Username, questionID)
			}
			created++
		}

		logger.Info(ctx, "Created daily assignments", map[string]interface{}{
			"username": assignmentData.Username,
			"date":     assignmentData.Date,
			"count":    created,
		})
	}
	return nil
}
func loadAndCreateStories(ctx context.Context, filePath string, users map[string]*models.User, db *sql.DB, logger *observability.Logger) (map[string]TestStoryData, error) {
stories := make(map[string]TestStoryData)
data, err := os.ReadFile(filePath)
if err != nil {
// Stories file is optional, so just return if it doesn't exist
logger.Info(ctx, "Stories file not found, skipping", map[string]interface{}{
"file_path": filePath,
})
return stories, nil
}
var testStories TestStories
if err := yaml.Unmarshal(data, &testStories); err != nil {
return stories, contextutils.WrapError(err, "failed to parse stories data")
}
for i, storyData := range testStories.Stories {
user, exists := users[storyData.Username]
if !exists {
return stories, contextutils.ErrorWithContextf("user not found for story: %s", storyData.Username)
}
// Parse section length override if provided
var sectionLengthOverride *models.SectionLength
if storyData.SectionLengthOverride != nil {
switch *storyData.SectionLengthOverride {
case "short":
sl := models.SectionLengthShort
sectionLengthOverride = &sl
case "medium":
sl := models.SectionLengthMedium
sectionLengthOverride = &sl
case "long":
sl := models.SectionLengthLong
sectionLengthOverride = &sl
}
}
// Create story
story := &models.Story{
UserID: uint(user.ID),
Title: storyData.Title,
Language: storyData.Language,
Subject: storyData.Subject,
AuthorStyle: storyData.AuthorStyle,
TimePeriod: storyData.TimePeriod,
Genre: storyData.Genre,
Tone: storyData.Tone,
CharacterNames: storyData.CharacterNames,
CustomInstructions: storyData.CustomInstructions,
SectionLengthOverride: sectionLengthOverride,
Status: models.StoryStatus(storyData.Status),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// Insert story directly into database
_, err := db.Exec(`
INSERT INTO stories (user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
`, story.UserID, story.Title, story.Language, story.Subject, story.AuthorStyle, story.TimePeriod,
story.Genre, story.Tone, story.CharacterNames, story.CustomInstructions, story.SectionLengthOverride,
string(story.Status), story.CreatedAt, story.UpdatedAt)
if err != nil {
return stories, contextutils.WrapErrorf(err, "failed to insert story %d", i)
}
// Get the story ID (we need to query it back since we don't have RETURNING)
var storyID int
err = db.QueryRow(`
SELECT id FROM stories WHERE user_id = $1 AND title = $2 ORDER BY created_at DESC LIMIT 1
`, story.UserID, story.Title).Scan(&storyID)
if err != nil {
return stories, contextutils.WrapErrorf(err, "failed to get story ID for story %d", i)
}
// Initialize story data for test output
storyKey := fmt.Sprintf("%s_%s", storyData.Username, storyData.Title)
storyDataForOutput := TestStoryData{
ID: storyID,
Username: storyData.Username,
Title: storyData.Title,
Status: storyData.Status,
Sections: []TestStorySectionData{},
}
// Create sections for this story
for j, sectionData := range storyData.Sections {
section := &models.StorySection{
StoryID: uint(storyID),
SectionNumber: sectionData.SectionNumber,
Content: sectionData.Content,
LanguageLevel: sectionData.LanguageLevel,
WordCount: sectionData.WordCount,
GeneratedBy: models.GeneratorType(sectionData.GeneratedBy),
GeneratedAt: time.Now(),
GenerationDate: time.Now(),
}
// Insert section
_, err := db.Exec(`
INSERT INTO story_sections (story_id, section_number, content, language_level, word_count,
generated_by, generated_at, generation_date)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`, section.StoryID, section.SectionNumber, section.Content, section.LanguageLevel,
section.WordCount, string(section.GeneratedBy), section.GeneratedAt, section.GenerationDate)
if err != nil {
return stories, contextutils.WrapErrorf(err, "failed to insert section %d for story %d", j, i)
}
// Get the section ID
var sectionID int
err = db.QueryRow(`
SELECT id FROM story_sections WHERE story_id = $1 AND section_number = $2
`, section.StoryID, section.SectionNumber).Scan(§ionID)
if err != nil {
return stories, contextutils.WrapErrorf(err, "failed to get section ID for section %d of story %d", j, i)
}
// Add section data to story data for test output
sectionDataForOutput := TestStorySectionData{
ID: sectionID,
StoryID: storyID,
SectionNumber: section.SectionNumber,
Content: section.Content,
LanguageLevel: section.LanguageLevel,
WordCount: section.WordCount,
GeneratedBy: string(section.GeneratedBy),
}
storyDataForOutput.Sections = append(storyDataForOutput.Sections, sectionDataForOutput)
// Create questions for this section
for k, questionData := range sectionData.Questions {
question := &models.StorySectionQuestion{
SectionID: uint(sectionID),
QuestionText: questionData.QuestionText,
Options: questionData.Options,
CorrectAnswerIndex: questionData.CorrectAnswerIndex,
Explanation: questionData.Explanation,
CreatedAt: time.Now(),
}
// Convert options to JSON for database storage
optionsJSON, err := json.Marshal(question.Options)
if err != nil {
return stories, contextutils.WrapErrorf(err, "failed to marshal options for question %d for section %d of story %d", k, j, i)
}
// Insert question
_, err = db.Exec(`
INSERT INTO story_section_questions (section_id, question_text, options, correct_answer_index, explanation, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
`, question.SectionID, question.QuestionText, optionsJSON, question.CorrectAnswerIndex,
question.Explanation, question.CreatedAt)
if err != nil {
return stories, contextutils.WrapErrorf(err, "failed to insert question %d for section %d of story %d", k, j, i)
}
}
}
// Store story data for test output after all sections are created
stories[storyKey] = storyDataForOutput
logger.Info(ctx, "Created test story", map[string]interface{}{
"username": storyData.Username,
"title": storyData.Title,
"story_id": storyID,
})
}
return stories, nil
}
// loadAndCreateSnippets reads the optional YAML file at filePath, decodes it
// into TestSnippets, and creates every listed snippet through a
// SnippetsService built on db.
//
// It returns a map keyed by "<username>_<original text>_<source language>"
// describing each created snippet. A missing file is logged and yields an
// empty map. Any other read error, a parse error, an unknown username, or a
// failed create aborts the load; the map returned alongside the error holds
// only the snippets created before the failure.
func loadAndCreateSnippets(ctx context.Context, filePath string, users map[string]*models.User, db *sql.DB, logger *observability.Logger) (map[string]TestSnippetData, error) {
	snippets := make(map[string]TestSnippetData)
	data, err := os.ReadFile(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			// Snippets file is optional, so just return if it doesn't exist
			logger.Info(ctx, "Snippets file not found, skipping", map[string]interface{}{
				"file_path": filePath,
			})
			return snippets, nil
		}
		// The file exists but could not be read; report it instead of
		// pretending it was absent.
		return snippets, contextutils.WrapErrorf(err, "failed to read snippets file: %s", filePath)
	}
	var testSnippets TestSnippets
	if err := yaml.Unmarshal(data, &testSnippets); err != nil {
		return snippets, contextutils.WrapError(err, "failed to parse snippets data")
	}

	// Create snippets service
	snippetsService := services.NewSnippetsService(db, nil, logger)
	for i, snippetData := range testSnippets.Snippets {
		user, exists := users[snippetData.Username]
		if !exists {
			return snippets, contextutils.ErrorWithContextf("user not found for snippet: %s", snippetData.Username)
		}

		// Create snippet request
		createReq := api.CreateSnippetRequest{
			OriginalText:   snippetData.OriginalText,
			TranslatedText: snippetData.TranslatedText,
			SourceLanguage: snippetData.SourceLanguage,
			TargetLanguage: snippetData.TargetLanguage,
			Context:        snippetData.Context,
		}

		// Create snippet using the service
		snippet, err := snippetsService.CreateSnippet(ctx, int64(user.ID), createReq)
		if err != nil {
			return snippets, contextutils.WrapErrorf(err, "failed to create snippet %d", i)
		}

		// Record what the service returned (not the raw fixture values) so
		// the output reflects the stored snippet.
		snippetKey := fmt.Sprintf("%s_%s_%s", snippetData.Username, snippetData.OriginalText, snippetData.SourceLanguage)
		snippets[snippetKey] = TestSnippetData{
			ID:             int(snippet.ID),
			Username:       snippetData.Username,
			OriginalText:   snippet.OriginalText,
			TranslatedText: snippet.TranslatedText,
			SourceLanguage: snippet.SourceLanguage,
			TargetLanguage: snippet.TargetLanguage,
		}
		logger.Info(ctx, "Created test snippet", map[string]interface{}{
			"username":      snippetData.Username,
			"original_text": snippetData.OriginalText,
			"snippet_id":    snippet.ID,
		})
	}
	return snippets, nil
}
// outputUserDataForTests writes an id/username/email summary of users to
// <rootDir>/../frontend/tests/test-users.json as indented JSON, keyed by the
// same keys as the input map, so the E2E tests can read it.
//
// The target directory is created if needed. Directory, marshal and write
// failures are returned wrapped; success is logged with the file path and
// the number of users written.
func outputUserDataForTests(users map[string]*models.User, rootDir string, logger *observability.Logger) error {
	// Only the fields the E2E tests need are exported.
	type TestUserData struct {
		ID       int    `json:"id"`
		Username string `json:"username"`
		Email    string `json:"email"`
	}

	summary := make(map[string]TestUserData, len(users))
	for key, account := range users {
		summary[key] = TestUserData{
			ID:       account.ID,
			Username: account.Username,
			Email:    account.Email.String,
		}
	}

	target := filepath.Join(rootDir, "..", "frontend", "tests", "test-users.json")
	parent := filepath.Dir(target)

	// Make sure the destination directory is there before writing.
	if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
		return contextutils.WrapErrorf(mkErr, "failed to create output directory: %s", parent)
	}

	encoded, encErr := json.MarshalIndent(summary, "", " ")
	if encErr != nil {
		return contextutils.WrapErrorf(encErr, "failed to marshal user data to JSON")
	}

	if writeErr := os.WriteFile(target, encoded, 0o644); writeErr != nil {
		return contextutils.WrapErrorf(writeErr, "failed to write user data to file: %s", target)
	}

	fields := map[string]interface{}{
		"file_path":  target,
		"user_count": len(summary),
	}
	logger.Info(context.Background(), "Output user data for E2E tests", fields)
	return nil
}
// outputStoryDataForTests writes stories as indented JSON to
// <rootDir>/../frontend/tests/test-stories.json so the E2E tests can read it.
//
// The target directory is created if needed. Directory, marshal and write
// failures are returned wrapped; success is logged with the file path and
// the number of stories written.
func outputStoryDataForTests(stories map[string]TestStoryData, rootDir string, logger *observability.Logger) error {
	target := filepath.Join(rootDir, "..", "frontend", "tests", "test-stories.json")
	parent := filepath.Dir(target)

	// Make sure the destination directory is there before writing.
	if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
		return contextutils.WrapErrorf(mkErr, "failed to create output directory: %s", parent)
	}

	encoded, encErr := json.MarshalIndent(stories, "", " ")
	if encErr != nil {
		return contextutils.WrapErrorf(encErr, "failed to marshal stories data to JSON")
	}

	if writeErr := os.WriteFile(target, encoded, 0o644); writeErr != nil {
		return contextutils.WrapErrorf(writeErr, "failed to write stories data to file: %s", target)
	}

	fields := map[string]interface{}{
		"file_path":     target,
		"stories_count": len(stories),
	}
	logger.Info(context.Background(), "Output stories data for E2E tests", fields)
	return nil
}
// outputSnippetDataForTests writes snippets as indented JSON to
// <rootDir>/../frontend/tests/test-snippets.json so the E2E tests can read it.
//
// The target directory is created if needed. Directory, marshal and write
// failures are returned wrapped; success is logged with the file path and
// the number of snippets written.
func outputSnippetDataForTests(snippets map[string]TestSnippetData, rootDir string, logger *observability.Logger) error {
	target := filepath.Join(rootDir, "..", "frontend", "tests", "test-snippets.json")
	parent := filepath.Dir(target)

	// Make sure the destination directory is there before writing.
	if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
		return contextutils.WrapErrorf(mkErr, "failed to create output directory: %s", parent)
	}

	encoded, encErr := json.MarshalIndent(snippets, "", " ")
	if encErr != nil {
		return contextutils.WrapErrorf(encErr, "failed to marshal snippets data to JSON")
	}

	if writeErr := os.WriteFile(target, encoded, 0o644); writeErr != nil {
		return contextutils.WrapErrorf(writeErr, "failed to write snippets data to file: %s", target)
	}

	fields := map[string]interface{}{
		"file_path":      target,
		"snippets_count": len(snippets),
	}
	logger.Info(context.Background(), "Output snippets data for E2E tests", fields)
	return nil
}
// outputRolesDataForTests reads every row of the roles table (ordered by id)
// and writes an id/name/description summary, keyed by role name, as indented
// JSON to <rootDir>/../frontend/tests/test-roles.json for the E2E tests.
//
// The target directory is created if needed. Query, scan, iteration,
// directory, marshal and write failures are returned wrapped; a failure to
// close the result set is only logged. Success is logged with the file path
// and the number of roles written.
func outputRolesDataForTests(db *sql.DB, rootDir string, logger *observability.Logger) error {
	rows, err := db.Query(`
SELECT id, name, description, created_at, updated_at
FROM roles
ORDER BY id
`)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to query roles from database")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		logger.Warn(context.Background(), "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
	}()

	// Only the fields the E2E tests need are exported; the timestamps are
	// scanned but not written out.
	type TestRoleData struct {
		ID          int    `json:"id"`
		Name        string `json:"name"`
		Description string `json:"description"`
	}

	summary := make(map[string]TestRoleData)
	for rows.Next() {
		var rec models.Role
		if scanErr := rows.Scan(&rec.ID, &rec.Name, &rec.Description, &rec.CreatedAt, &rec.UpdatedAt); scanErr != nil {
			return contextutils.WrapErrorf(scanErr, "failed to scan role data")
		}
		summary[rec.Name] = TestRoleData{
			ID:          rec.ID,
			Name:        rec.Name,
			Description: rec.Description,
		}
	}
	if iterErr := rows.Err(); iterErr != nil {
		return contextutils.WrapErrorf(iterErr, "error iterating over roles")
	}

	target := filepath.Join(rootDir, "..", "frontend", "tests", "test-roles.json")
	parent := filepath.Dir(target)

	// Make sure the destination directory is there before writing.
	if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
		return contextutils.WrapErrorf(mkErr, "failed to create output directory: %s", parent)
	}

	encoded, encErr := json.MarshalIndent(summary, "", " ")
	if encErr != nil {
		return contextutils.WrapErrorf(encErr, "failed to marshal roles data to JSON")
	}

	if writeErr := os.WriteFile(target, encoded, 0o644); writeErr != nil {
		return contextutils.WrapErrorf(writeErr, "failed to write roles data to file: %s", target)
	}

	fields := map[string]interface{}{
		"file_path":   target,
		"roles_count": len(summary),
	}
	logger.Info(context.Background(), "Output roles data for E2E tests", fields)
	return nil
}
// loadAndCreateConversations reads the optional YAML file at filePath,
// decodes it into TestConversations, and creates every conversation and its
// messages through a ConversationService built on db.
//
// After the messages are added they are read back from the service so the
// returned data carries the stored IDs and timestamps. The result is keyed by
// "<username>_<title>"; a conversation is added to it only once its messages
// have been read back. Messages is always a non-nil slice, so a conversation
// without messages is marshalled as an empty JSON array rather than null.
//
// A missing file is logged and yields an empty map. Any other read error, a
// parse error, an unknown username, or a failed service call aborts the load;
// the map returned alongside the error holds only the conversations completed
// before the failure.
func loadAndCreateConversations(ctx context.Context, filePath string, users map[string]*models.User, db *sql.DB, logger *observability.Logger) (map[string]TestConversationData, error) {
	conversations := make(map[string]TestConversationData)
	data, err := os.ReadFile(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			// Conversations file is optional, so just return if it doesn't exist
			logger.Info(ctx, "Conversations file not found, skipping", map[string]interface{}{
				"file_path": filePath,
			})
			return conversations, nil
		}
		// The file exists but could not be read; report it instead of
		// pretending it was absent.
		return conversations, contextutils.WrapErrorf(err, "failed to read conversations file: %s", filePath)
	}
	var testConversations TestConversations
	if err := yaml.Unmarshal(data, &testConversations); err != nil {
		return conversations, contextutils.WrapError(err, "failed to parse conversations data")
	}

	// Create conversation service
	conversationService := services.NewConversationService(db)
	for i, convData := range testConversations.Conversations {
		user, exists := users[convData.Username]
		if !exists {
			return conversations, contextutils.ErrorWithContextf("user not found for conversation: %s", convData.Username)
		}

		// Create conversation
		createReq := &api.CreateConversationRequest{
			Title: convData.Title,
		}
		conversation, err := conversationService.CreateConversation(ctx, uint(user.ID), createReq)
		if err != nil {
			return conversations, contextutils.WrapErrorf(err, "failed to create conversation %d", i)
		}
		conversationID := conversation.Id.String()

		// Create messages for this conversation
		for j, msgData := range convData.Messages {
			// Copy the text so the request points at its own string and not
			// at the range variable.
			text := msgData.Content
			content := struct {
				Text *string `json:"text,omitempty"`
			}{
				Text: &text,
			}
			createMsgReq := &api.CreateMessageRequest{
				Content:    content,
				Role:       api.CreateMessageRequestRole(msgData.Role),
				QuestionId: msgData.QuestionID,
			}
			if _, err := conversationService.AddMessage(ctx, conversationID, uint(user.ID), createMsgReq); err != nil {
				return conversations, contextutils.WrapErrorf(err, "failed to add message %d for conversation %d", j, i)
			}
		}

		// Now retrieve all messages for this conversation to get their actual data
		messages, err := conversationService.GetConversationMessages(ctx, conversationID, uint(user.ID))
		if err != nil {
			return conversations, contextutils.WrapErrorf(err, "failed to get messages for conversation %d", i)
		}

		// Convert messages to our test data format. The slice is allocated
		// up front so it is non-nil even when there are no messages.
		testMessages := make([]TestMessageData, 0, len(messages))
		for _, msg := range messages {
			testMsg := TestMessageData{
				ID:             msg.Id.String(),
				ConversationID: msg.ConversationId.String(),
				Role:           string(msg.Role),
				Bookmarked:     false, // Default value
				CreatedAt:      msg.CreatedAt.Format(time.RFC3339),
				UpdatedAt:      msg.UpdatedAt.Format(time.RFC3339),
			}
			if msg.QuestionId != nil {
				testMsg.QuestionID = msg.QuestionId
			}
			if msg.Content.Text != nil {
				testMsg.Content = *msg.Content.Text
			}
			testMessages = append(testMessages, testMsg)
		}

		// Store conversation data for test output
		convKey := fmt.Sprintf("%s_%s", convData.Username, convData.Title)
		conversations[convKey] = TestConversationData{
			ID:       conversationID,
			Username: convData.Username,
			Title:    convData.Title,
			Messages: testMessages,
		}
		logger.Info(ctx, "Created test conversation", map[string]interface{}{
			"username":        convData.Username,
			"title":           convData.Title,
			"conversation_id": conversation.Id,
		})
	}
	return conversations, nil
}
// outputConversationDataForTests writes conversations as indented JSON to
// <rootDir>/../frontend/tests/test-conversations.json so the E2E tests can
// read it.
//
// The target directory is created if needed. Directory, marshal and write
// failures are returned wrapped; success is logged with the file path and
// the number of conversations written.
func outputConversationDataForTests(conversations map[string]TestConversationData, rootDir string, logger *observability.Logger) error {
	target := filepath.Join(rootDir, "..", "frontend", "tests", "test-conversations.json")
	parent := filepath.Dir(target)

	// Make sure the destination directory is there before writing.
	if mkErr := os.MkdirAll(parent, 0o755); mkErr != nil {
		return contextutils.WrapErrorf(mkErr, "failed to create output directory: %s", parent)
	}

	encoded, encErr := json.MarshalIndent(conversations, "", " ")
	if encErr != nil {
		return contextutils.WrapErrorf(encErr, "failed to marshal conversations data to JSON")
	}

	if writeErr := os.WriteFile(target, encoded, 0o644); writeErr != nil {
		return contextutils.WrapErrorf(writeErr, "failed to write conversations data to file: %s", target)
	}

	fields := map[string]interface{}{
		"file_path":           target,
		"conversations_count": len(conversations),
	}
	logger.Info(context.Background(), "Output conversations data for E2E tests", fields)
	return nil
}
// loadAndCreateFeedback seeds feedback reports from the YAML file at filePath
// and returns the created rows keyed by "<username>_<index>".
//
// The feedback file is optional: when it does not exist the function logs the
// fact and returns an empty map with a nil error. Any other read failure
// (permissions, I/O error, path is a directory, ...) is returned as an error
// instead of being reported as "not found".
//
// Every report must reference a username present in users. An empty
// feedback_type defaults to "general" and an empty status defaults to "new".
// Rows are inserted directly into feedback_reports using ctx, so cancelling
// ctx aborts the in-flight insert.
//
// On error the returned map still holds the reports inserted before the
// failure.
func loadAndCreateFeedback(ctx context.Context, filePath string, users map[string]*models.User, db *sql.DB, logger *observability.Logger) (map[string]TestFeedbackData, error) {
	feedback := make(map[string]TestFeedbackData)

	data, err := os.ReadFile(filePath)
	if err != nil {
		// Only a missing file is treated as "optional data not provided".
		if !os.IsNotExist(err) {
			return feedback, contextutils.WrapErrorf(err, "failed to read feedback data file: %s", filePath)
		}
		logger.Info(ctx, "Feedback file not found, skipping", map[string]interface{}{
			"file_path": filePath,
		})
		return feedback, nil
	}

	var testFeedback TestFeedback
	if err := yaml.Unmarshal(data, &testFeedback); err != nil {
		return feedback, contextutils.WrapError(err, "failed to parse feedback data")
	}

	for i, feedbackData := range testFeedback.FeedbackReports {
		user, exists := users[feedbackData.Username]
		if !exists {
			return feedback, contextutils.ErrorWithContextf("user not found for feedback: %s", feedbackData.Username)
		}

		// Default values
		feedbackType := feedbackData.FeedbackType
		if feedbackType == "" {
			feedbackType = "general"
		}
		status := feedbackData.Status
		if status == "" {
			status = "new"
		}

		// Marshal context_data to JSON
		contextJSON, err := json.Marshal(feedbackData.ContextData)
		if err != nil {
			return feedback, contextutils.WrapErrorf(err, "failed to marshal context_data for feedback %d", i)
		}

		// Insert feedback directly into database; RETURNING hands back the
		// generated primary key in the same round trip.
		var feedbackID int
		err = db.QueryRowContext(ctx, `
INSERT INTO feedback_reports (user_id, feedback_text, feedback_type, context_data, status, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())
RETURNING id
`, user.ID, feedbackData.FeedbackText, feedbackType, contextJSON, status).Scan(&feedbackID)
		if err != nil {
			return feedback, contextutils.WrapErrorf(err, "failed to insert feedback %d", i)
		}

		// Store feedback data for test output. The index keeps keys unique
		// when one user has several reports.
		feedbackKey := fmt.Sprintf("%s_%d", feedbackData.Username, i)
		feedback[feedbackKey] = TestFeedbackData{
			ID:           feedbackID,
			Username:     feedbackData.Username,
			FeedbackText: feedbackData.FeedbackText,
			FeedbackType: feedbackType,
			Status:       status,
			ContextData:  feedbackData.ContextData,
		}

		logger.Info(ctx, "Created test feedback", map[string]interface{}{
			"username":      feedbackData.Username,
			"feedback_id":   feedbackID,
			"status":        status,
			"feedback_type": feedbackType,
		})
	}

	return feedback, nil
}
// outputFeedbackDataForTests dumps the seeded feedback reports as indented
// JSON to <rootDir>/../frontend/tests/test-feedback.json for the E2E tests.
//
// The destination directory is created if it is missing. Errors from creating
// the directory, marshalling, or writing the file are wrapped and returned;
// a successful write is logged with the file path and number of entries.
func outputFeedbackDataForTests(feedback map[string]TestFeedbackData, rootDir string, logger *observability.Logger) error {
	var (
		destFile = filepath.Join(rootDir, "..", "frontend", "tests", "test-feedback.json")
		destDir  = filepath.Dir(destFile)
	)

	if dirErr := os.MkdirAll(destDir, 0o755); dirErr != nil {
		return contextutils.WrapErrorf(dirErr, "failed to create output directory: %s", destDir)
	}

	encoded, encErr := json.MarshalIndent(feedback, "", " ")
	if encErr != nil {
		return contextutils.WrapErrorf(encErr, "failed to marshal feedback data to JSON")
	}

	if ioErr := os.WriteFile(destFile, encoded, 0o644); ioErr != nil {
		return contextutils.WrapErrorf(ioErr, "failed to write feedback data to file: %s", destFile)
	}

	summary := map[string]interface{}{
		"file_path":      destFile,
		"feedback_count": len(feedback),
	}
	logger.Info(context.Background(), "Output feedback data for E2E tests", summary)
	return nil
}
// Package main provides the entry point for the Quiz Application worker service.
package main
import (
"context"
"io/fs"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/handlers"
"quizapp/internal/middleware"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/version"
"quizapp/internal/worker"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
)
// fatalIfErr logs err with the given message and fields, then panics with
// "<msg>: <err>".
//
// A nil err is a no-op, so the helper is safe to call unconditionally;
// previously a nil err caused a nil-pointer dereference while building the
// panic message. Panicking (rather than os.Exit) lets deferred cleanup in the
// calling goroutine run.
func fatalIfErr(ctx context.Context, logger *observability.Logger, msg string, err error, fields map[string]interface{}) {
	if err == nil {
		return
	}
	logger.Error(ctx, msg, err, fields)
	panic(msg + ": " + err.Error())
}
// main wires up and runs the quiz worker service.
//
// Startup order: load configuration, set up observability, open the database
// (without running migrations), construct the services, start the background
// worker in its own goroutine, build the Gin router (request logging, tracing,
// CORS, sessions, routes) and start the HTTP server on cfg.Server.WorkerPort.
//
// The function then blocks until SIGINT or SIGTERM is received and shuts down
// the worker first and the HTTP server second, both bounded by
// config.WorkerShutdownTimeout. Unrecoverable startup or shutdown errors panic.
func main() {
	ctx := context.Background()

	// Load configuration
	cfg, err := config.NewConfig()
	if err != nil {
		panic("Failed to load configuration: " + err.Error())
	}

	// Setup observability (tracing/metrics/logging)
	tp, mp, logger, err := observability.SetupObservability(&cfg.OpenTelemetry, "quiz-worker")
	if err != nil {
		panic("Failed to initialize observability: " + err.Error())
	}
	defer func() {
		if tp != nil {
			if err := tp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error(), "provider": "tracer"})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error(), "provider": "meter"})
			}
		}
	}()

	logger.Info(ctx, "Starting quiz worker service", map[string]interface{}{
		"port":     cfg.Server.WorkerPort,
		"logLevel": cfg.Server.LogLevel,
		"debug":    cfg.Server.Debug,
	})

	// Initialize database manager with logger
	dbManager := database.NewManager(logger)

	// Initialize database connection without running migrations (migrations are managed elsewhere)
	// NOTE(review): db_url is logged below; if the URL embeds credentials they
	// end up in the logs — confirm the logger redacts it.
	db, err := dbManager.InitDBWithoutMigrations(cfg.Database)
	if err != nil {
		fatalIfErr(ctx, logger, "Failed to initialize database", err, map[string]interface{}{"db_url": cfg.Database.URL})
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database", map[string]interface{}{"error": err.Error(), "db_url": cfg.Database.URL})
		}
	}()

	// Initialize services
	userService := services.NewUserServiceWithLogger(db, cfg, logger)
	learningService := services.NewLearningServiceWithLogger(db, cfg, logger)
	// Create question service
	questionService := services.NewQuestionServiceWithLogger(db, learningService, cfg, logger)
	// Create usage stats service
	usageStatsService := services.NewUsageStatsService(cfg, db, logger)
	aiService := services.NewAIService(cfg, logger, usageStatsService)
	workerService := services.NewWorkerServiceWithLogger(db, logger)
	generationHintService := services.NewGenerationHintService(db, logger)
	emailService := services.CreateEmailServiceWithDB(cfg, logger, db)
	// Create daily question service
	dailyQuestionService := services.NewDailyQuestionService(db, logger, questionService, learningService)
	// Create word of the day service
	wordOfTheDayService := services.NewWordOfTheDayService(db, logger)
	// Create story service
	storyService := services.NewStoryService(db, cfg, logger)
	// Create translation cache repository
	translationCacheRepo := services.NewTranslationCacheRepository(db, logger)

	// Initialize worker with the observability logger
	workerInstance := worker.NewWorker(userService, questionService, aiService, learningService, workerService, dailyQuestionService, wordOfTheDayService, storyService, emailService, generationHintService, translationCacheRepo, "default", cfg, logger)
	go workerInstance.Start(ctx)

	// Initialize admin handler for worker UI
	adminHandler := handlers.NewWorkerAdminHandlerWithLogger(userService, questionService, aiService, cfg, workerInstance, workerService, learningService, dailyQuestionService, logger)

	// Setup Gin router
	gin.SetMode(gin.ReleaseMode)
	if cfg.Server.Debug {
		gin.SetMode(gin.DebugMode)
	}
	router := gin.New()
	router.Use(gin.Recovery())

	// Add HTTP request logging middleware using our observability logger
	router.Use(func(c *gin.Context) {
		start := time.Now()

		// Process request
		c.Next()

		// Log request details using our observability logger
		latency := time.Since(start)
		statusCode := c.Writer.Status()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		// Create structured log entry
		fields := map[string]interface{}{
			"http.method":      method,
			"http.path":        path,
			"http.status_code": statusCode,
			"http.latency_ms":  latency.Milliseconds(),
			"http.client_ip":   clientIP,
			"http.user_agent":  c.Request.UserAgent(),
		}

		// Add error message if present
		if len(c.Errors) > 0 {
			fields["http.error"] = c.Errors.String()
		}

		// Log using our observability logger (goes to both stdout and OTLP)
		// Use appropriate log level based on status code
		if statusCode >= 500 {
			logger.Error(c.Request.Context(), "HTTP request failed", nil, fields)
		} else if statusCode >= 400 {
			logger.Warn(c.Request.Context(), "HTTP request warning", fields)
		} else {
			logger.Info(c.Request.Context(), "HTTP request", fields)
		}
	})

	// Add OpenTelemetry middleware for HTTP tracing with automatic error attributes
	router.Use(observability.GinMiddlewareWithErrorHandling("quiz-worker"))

	// Add CORS middleware
	router.Use(func(c *gin.Context) {
		c.Header("Access-Control-Allow-Origin", "*")
		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
		// Answer preflight requests directly; they never reach a route handler.
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	})

	// Setup session middleware
	store := cookie.NewStore([]byte(cfg.Server.SessionSecret))
	router.Use(sessions.Sessions(config.SessionName, store))

	// Setup routes
	v1 := router.Group("/v1")
	{
		// Health check route
		v1.GET("/health", func(c *gin.Context) {
			c.JSON(http.StatusOK, gin.H{"status": "ok"})
		})
		// Version route
		v1.GET("/version", func(c *gin.Context) {
			c.JSON(http.StatusOK, gin.H{
				"service":   "worker",
				"version":   version.Version,
				"commit":    version.Commit,
				"buildTime": version.BuildTime,
			})
		})
	}

	// Serve static assets (CSS/JS) for worker admin dashboard.
	// A failure here would leave a nil filesystem behind the /worker route,
	// so it is treated as a fatal startup error instead of being ignored.
	staticFS, err := fs.Sub(handlers.AssetsFS, "templates/assets")
	if err != nil {
		fatalIfErr(ctx, logger, "Failed to load worker static assets", err, map[string]interface{}{"path": "templates/assets"})
	}
	router.StaticFS("/worker", http.FS(staticFS))

	// Config dump endpoint
	router.GET("/configz", adminHandler.GetConfigz)

	// API routes for worker management
	api := router.Group("/v1")
	{
		// Admin worker endpoints (for frontend)
		adminWorker := api.Group("/admin/worker")
		adminWorker.Use(middleware.RequireAuth())
		{
			adminWorker.GET("/details", adminHandler.GetWorkerDetails)
			adminWorker.GET("/status", adminHandler.GetWorkerStatus)
			adminWorker.GET("/logs", adminHandler.GetActivityLogs)
			adminWorker.POST("/pause", adminHandler.PauseWorker)
			adminWorker.POST("/resume", adminHandler.ResumeWorker)
			adminWorker.POST("/trigger", adminHandler.TriggerWorkerRun)
			adminWorker.GET("/ai-concurrency", adminHandler.GetAIConcurrencyStats)
		}

		// Worker user control endpoints (for pausing/resuming user question generation)
		workerUsers := api.Group("/admin/worker/users")
		workerUsers.Use(middleware.RequireAuth())
		{
			workerUsers.GET("/", adminHandler.GetWorkerUsers)
			workerUsers.POST("/pause", adminHandler.PauseWorkerUser)
			workerUsers.POST("/resume", adminHandler.ResumeWorkerUser)
		}

		// System health for worker
		system := api.Group("/system")
		{
			system.GET("/health", adminHandler.GetSystemHealth)
		}

		// Admin analytics endpoints (for frontend)
		adminAnalytics := api.Group("/admin/worker/analytics")
		adminAnalytics.Use(middleware.RequireAuth())
		{
			adminAnalytics.GET("/priority-scores", adminHandler.GetPriorityAnalytics)
			adminAnalytics.GET("/user-performance", adminHandler.GetUserPerformanceAnalytics)
			adminAnalytics.GET("/generation-intelligence", adminHandler.GetGenerationIntelligence)
			adminAnalytics.GET("/system-health", adminHandler.GetSystemHealthAnalytics)
			adminAnalytics.GET("/comparison", adminHandler.GetUserComparisonAnalytics)
			adminAnalytics.GET("/user/:userID", adminHandler.GetUserPriorityAnalytics)
		}

		// Admin daily questions endpoints (for frontend)
		adminDaily := api.Group("/admin/worker/daily")
		adminDaily.Use(middleware.RequireAuth())
		{
			adminDaily.GET("/users/:userId/questions/:date", adminHandler.GetUserDailyQuestions)
			adminDaily.POST("/users/:userId/questions/:date/regenerate", adminHandler.RegenerateUserDailyQuestions)
		}

		// Admin notification endpoints (for frontend)
		adminNotifications := api.Group("/admin/worker/notifications")
		adminNotifications.Use(middleware.RequireAuth())
		{
			adminNotifications.GET("/stats", adminHandler.GetNotificationStats)
			adminNotifications.GET("/errors", adminHandler.GetNotificationErrors)
			adminNotifications.GET("/sent", adminHandler.GetSentNotifications)
			adminNotifications.POST("/test/create-sent", adminHandler.CreateTestSentNotification)
			adminNotifications.POST("/force-send", adminHandler.ForceSendNotification)
		}
	}

	// Automatic route listing at root path. Routes are collected before the
	// root handler itself is registered.
	routeListing := handlers.NewRouteListingHandler("Worker")
	routeListing.CollectRoutes(router)

	// Root path shows all available routes
	router.GET("/", func(c *gin.Context) {
		// Support JSON output via query parameter
		if c.Query("json") == "true" {
			routeListing.GetRouteListingJSON(c)
		} else {
			routeListing.GetRouteListingPage(c)
		}
	})

	// Create HTTP server. ReadHeaderTimeout bounds how long a client may take
	// to send request headers, so stalled connections cannot be held open
	// indefinitely; body and handler durations are left unrestricted.
	srv := &http.Server{
		Addr:              ":" + cfg.Server.WorkerPort,
		Handler:           router,
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Start server in a goroutine
	go func() {
		logger.Info(ctx, "Worker server starting", map[string]interface{}{"port": cfg.Server.WorkerPort})
		// http.ErrServerClosed is the normal result of srv.Shutdown, not a failure.
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			fatalIfErr(ctx, logger, "Failed to start worker server", err, map[string]interface{}{"port": cfg.Server.WorkerPort})
		}
	}()

	// Wait for interrupt signal to gracefully shutdown.
	// The channel is buffered because signal.Notify does not block on send.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
	logger.Info(ctx, "Worker server shutting down", map[string]interface{}{"service": "worker"})

	// Graceful shutdown with timeout (shared by the worker and the HTTP server)
	shutdownCtx, shutdownCancel := context.WithTimeout(ctx, config.WorkerShutdownTimeout)
	defer shutdownCancel()

	// Shutdown the worker first
	if err := workerInstance.Shutdown(shutdownCtx); err != nil {
		logger.Warn(ctx, "Warning: failed to shutdown worker", map[string]interface{}{"error": err.Error(), "service": "worker"})
	}

	// Then shutdown the server
	if err := srv.Shutdown(shutdownCtx); err != nil {
		fatalIfErr(ctx, logger, "Worker server forced to shutdown", err, map[string]interface{}{"service": "worker"})
	}

	logger.Info(ctx, "Worker server exited", map[string]interface{}{"service": "worker"})
}
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.0 DO NOT EDIT.
package api
import (
"encoding/json"
"fmt"
"time"
"github.com/oapi-codegen/runtime"
openapi_types "github.com/oapi-codegen/runtime/types"
)
// String keys named after the API's authentication schemes (apiKeyQuery,
// bearerAuth, cookieAuth, sessionAuth), each suffixed with ".Scopes".
// NOTE(review): presumably used by the generated server wrappers as
// request-context keys for the scopes an operation requires — the usage is
// not visible in this chunk; confirm.
const (
ApiKeyQueryScopes = "apiKeyQuery.Scopes"
BearerAuthScopes = "bearerAuth.Scopes"
CookieAuthScopes = "cookieAuth.Scopes"
SessionAuthScopes = "sessionAuth.Scopes"
)
// Defines values for APIKeySummaryPermissionLevel
// (permission level of an API key as shown in key summaries).
const (
APIKeySummaryPermissionLevelFull APIKeySummaryPermissionLevel = "full"
APIKeySummaryPermissionLevelReadonly APIKeySummaryPermissionLevel = "readonly"
)
// Defines values for APIKeyTestResponsePermissionLevel
// (permission level reported by the API key test response).
const (
APIKeyTestResponsePermissionLevelFull APIKeyTestResponsePermissionLevel = "full"
APIKeyTestResponsePermissionLevelReadonly APIKeyTestResponsePermissionLevel = "readonly"
)
// Defines values for ChatMessageRole (sender of a stored chat message).
const (
ChatMessageRoleAssistant ChatMessageRole = "assistant"
ChatMessageRoleUser ChatMessageRole = "user"
)
// Defines values for CreateAPIKeyRequestPermissionLevel
// (permission level requested when creating an API key).
const (
CreateAPIKeyRequestPermissionLevelFull CreateAPIKeyRequestPermissionLevel = "full"
CreateAPIKeyRequestPermissionLevelReadonly CreateAPIKeyRequestPermissionLevel = "readonly"
)
// Defines values for CreateAPIKeyResponsePermissionLevel
// (permission level echoed back after creating an API key).
const (
CreateAPIKeyResponsePermissionLevelFull CreateAPIKeyResponsePermissionLevel = "full"
CreateAPIKeyResponsePermissionLevelReadonly CreateAPIKeyResponsePermissionLevel = "readonly"
)
// Defines values for CreateMessageRequestRole
// (sender role supplied when creating a chat message).
const (
CreateMessageRequestRoleAssistant CreateMessageRequestRole = "assistant"
CreateMessageRequestRoleUser CreateMessageRequestRole = "user"
)
// Defines values for CreateStoryRequestSectionLengthOverride.
const (
CreateStoryRequestSectionLengthOverrideLong CreateStoryRequestSectionLengthOverride = "long"
CreateStoryRequestSectionLengthOverrideMedium CreateStoryRequestSectionLengthOverride = "medium"
CreateStoryRequestSectionLengthOverrideShort CreateStoryRequestSectionLengthOverride = "short"
)
// Defines values for ErrorResponseSeverity.
const (
ErrorResponseSeverityError ErrorResponseSeverity = "error"
ErrorResponseSeverityFatal ErrorResponseSeverity = "fatal"
ErrorResponseSeverityInfo ErrorResponseSeverity = "info"
ErrorResponseSeverityWarn ErrorResponseSeverity = "warn"
)
// Defines values for FeedbackReportFeedbackType.
// The string values are identical to those of
// FeedbackSubmissionRequestFeedbackType declared in this file.
const (
FeedbackReportFeedbackTypeBug FeedbackReportFeedbackType = "bug"
FeedbackReportFeedbackTypeFeatureRequest FeedbackReportFeedbackType = "feature_request"
FeedbackReportFeedbackTypeGeneral FeedbackReportFeedbackType = "general"
FeedbackReportFeedbackTypeImprovement FeedbackReportFeedbackType = "improvement"
)
// Defines values for FeedbackReportStatus.
// The string values are identical to those of FeedbackUpdateRequestStatus
// declared in this file.
const (
FeedbackReportStatusDismissed FeedbackReportStatus = "dismissed"
FeedbackReportStatusInProgress FeedbackReportStatus = "in_progress"
FeedbackReportStatusNew FeedbackReportStatus = "new"
FeedbackReportStatusResolved FeedbackReportStatus = "resolved"
)
// Defines values for FeedbackSubmissionRequestFeedbackType.
const (
FeedbackSubmissionRequestFeedbackTypeBug FeedbackSubmissionRequestFeedbackType = "bug"
FeedbackSubmissionRequestFeedbackTypeFeatureRequest FeedbackSubmissionRequestFeedbackType = "feature_request"
FeedbackSubmissionRequestFeedbackTypeGeneral FeedbackSubmissionRequestFeedbackType = "general"
FeedbackSubmissionRequestFeedbackTypeImprovement FeedbackSubmissionRequestFeedbackType = "improvement"
)
// Defines values for FeedbackUpdateRequestStatus.
const (
FeedbackUpdateRequestStatusDismissed FeedbackUpdateRequestStatus = "dismissed"
FeedbackUpdateRequestStatusInProgress FeedbackUpdateRequestStatus = "in_progress"
FeedbackUpdateRequestStatusNew FeedbackUpdateRequestStatus = "new"
FeedbackUpdateRequestStatusResolved FeedbackUpdateRequestStatus = "resolved"
)
// Defines values for NotificationErrorErrorType
// (category of a recorded notification error).
const (
NotificationErrorErrorTypeEmailDisabled NotificationErrorErrorType = "email_disabled"
NotificationErrorErrorTypeOther NotificationErrorErrorType = "other"
NotificationErrorErrorTypeSmtpError NotificationErrorErrorType = "smtp_error"
NotificationErrorErrorTypeTemplateError NotificationErrorErrorType = "template_error"
NotificationErrorErrorTypeUserNotFound NotificationErrorErrorType = "user_not_found"
)
// Defines values for NotificationErrorNotificationType.
const (
NotificationErrorNotificationTypeDailyReminder NotificationErrorNotificationType = "daily_reminder"
NotificationErrorNotificationTypeTestEmail NotificationErrorNotificationType = "test_email"
)
// Defines values for QuestionStatus.
const (
QuestionStatusActive QuestionStatus = "active"
QuestionStatusReported QuestionStatus = "reported"
)
// Defines values for QuestionType.
// Unlike most enums in this file, these constant names carry no type-name
// prefix.
const (
FillBlank QuestionType = "fill_blank"
Qa QuestionType = "qa"
ReadingComprehension QuestionType = "reading_comprehension"
Vocabulary QuestionType = "vocabulary"
)
// Defines values for SentNotificationNotificationType.
const (
SentNotificationNotificationTypeDailyReminder SentNotificationNotificationType = "daily_reminder"
SentNotificationNotificationTypeTestEmail SentNotificationNotificationType = "test_email"
)
// Defines values for SentNotificationStatus.
const (
SentNotificationStatusBounced SentNotificationStatus = "bounced"
SentNotificationStatusFailed SentNotificationStatus = "failed"
SentNotificationStatusSent SentNotificationStatus = "sent"
)
// Defines values for StorySectionLengthOverride.
const (
StorySectionLengthOverrideLong StorySectionLengthOverride = "long"
StorySectionLengthOverrideMedium StorySectionLengthOverride = "medium"
StorySectionLengthOverrideShort StorySectionLengthOverride = "short"
)
// Defines values for StoryStatus.
const (
StoryStatusActive StoryStatus = "active"
StoryStatusArchived StoryStatus = "archived"
StoryStatusCompleted StoryStatus = "completed"
)
// Defines values for StoryWithSectionsSectionLengthOverride.
// These constant names carry no type-name prefix; the string values match
// StorySectionLengthOverride.
const (
Long StoryWithSectionsSectionLengthOverride = "long"
Medium StoryWithSectionsSectionLengthOverride = "medium"
Short StoryWithSectionsSectionLengthOverride = "short"
)
// Defines values for StoryWithSectionsStatus.
// These constant names carry no type-name prefix; the string values match
// StoryStatus.
const (
Active StoryWithSectionsStatus = "active"
Archived StoryWithSectionsStatus = "archived"
Completed StoryWithSectionsStatus = "completed"
)
// Defines values for TTSRequestStreamFormat.
// These constant names carry no type-name prefix.
const (
Audio TTSRequestStreamFormat = "audio"
AudioStream TTSRequestStreamFormat = "audio_stream"
Sse TTSRequestStreamFormat = "sse"
)
// Defines values for TTSResponseType.
const (
TTSResponseTypeAudio TTSResponseType = "audio"
TTSResponseTypeError TTSResponseType = "error"
TTSResponseTypeUsage TTSResponseType = "usage"
)
// Defines values for WordOfTheDayDisplaySourceType.
const (
WordOfTheDayDisplaySourceTypeSnippet WordOfTheDayDisplaySourceType = "snippet"
WordOfTheDayDisplaySourceTypeVocabularyQuestion WordOfTheDayDisplaySourceType = "vocabulary_question"
)
// Defines values for WorkerStatusStatus.
// These constant names carry no type-name prefix.
const (
Busy WorkerStatusStatus = "busy"
Error WorkerStatusStatus = "error"
Idle WorkerStatusStatus = "idle"
)
// The enums below are named after operation parameter types
// ("<Operation>Params<Field>").
// NOTE(review): presumably these back request query parameters of the named
// operations — the Params structs are not in this chunk; confirm.

// Defines values for DeleteV1AdminBackendFeedbackParamsStatus.
const (
DeleteV1AdminBackendFeedbackParamsStatusDismissed DeleteV1AdminBackendFeedbackParamsStatus = "dismissed"
DeleteV1AdminBackendFeedbackParamsStatusInProgress DeleteV1AdminBackendFeedbackParamsStatus = "in_progress"
DeleteV1AdminBackendFeedbackParamsStatusNew DeleteV1AdminBackendFeedbackParamsStatus = "new"
DeleteV1AdminBackendFeedbackParamsStatusResolved DeleteV1AdminBackendFeedbackParamsStatus = "resolved"
)
// Defines values for GetV1AdminBackendFeedbackParamsStatus.
const (
GetV1AdminBackendFeedbackParamsStatusDismissed GetV1AdminBackendFeedbackParamsStatus = "dismissed"
GetV1AdminBackendFeedbackParamsStatusInProgress GetV1AdminBackendFeedbackParamsStatus = "in_progress"
GetV1AdminBackendFeedbackParamsStatusNew GetV1AdminBackendFeedbackParamsStatus = "new"
GetV1AdminBackendFeedbackParamsStatusResolved GetV1AdminBackendFeedbackParamsStatus = "resolved"
)
// Defines values for GetV1AdminBackendUserzPaginatedParamsAiEnabled.
// The boolean is carried as the strings "true"/"false".
const (
GetV1AdminBackendUserzPaginatedParamsAiEnabledFalse GetV1AdminBackendUserzPaginatedParamsAiEnabled = "false"
GetV1AdminBackendUserzPaginatedParamsAiEnabledTrue GetV1AdminBackendUserzPaginatedParamsAiEnabled = "true"
)
// Defines values for GetV1AdminBackendUserzPaginatedParamsActive.
// The boolean is carried as the strings "true"/"false".
const (
GetV1AdminBackendUserzPaginatedParamsActiveFalse GetV1AdminBackendUserzPaginatedParamsActive = "false"
GetV1AdminBackendUserzPaginatedParamsActiveTrue GetV1AdminBackendUserzPaginatedParamsActive = "true"
)
// Defines values for GetV1AdminWorkerNotificationsErrorsParamsErrorType.
const (
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeEmailDisabled GetV1AdminWorkerNotificationsErrorsParamsErrorType = "email_disabled"
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeOther GetV1AdminWorkerNotificationsErrorsParamsErrorType = "other"
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeSmtpError GetV1AdminWorkerNotificationsErrorsParamsErrorType = "smtp_error"
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeTemplateError GetV1AdminWorkerNotificationsErrorsParamsErrorType = "template_error"
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeUserNotFound GetV1AdminWorkerNotificationsErrorsParamsErrorType = "user_not_found"
)
// Defines values for GetV1AdminWorkerNotificationsErrorsParamsNotificationType.
const (
GetV1AdminWorkerNotificationsErrorsParamsNotificationTypeDailyReminder GetV1AdminWorkerNotificationsErrorsParamsNotificationType = "daily_reminder"
GetV1AdminWorkerNotificationsErrorsParamsNotificationTypeTestEmail GetV1AdminWorkerNotificationsErrorsParamsNotificationType = "test_email"
)
// Defines values for GetV1AdminWorkerNotificationsErrorsParamsResolved.
// These constant names carry no type-name prefix; the boolean is carried as
// the strings "true"/"false".
const (
False GetV1AdminWorkerNotificationsErrorsParamsResolved = "false"
True GetV1AdminWorkerNotificationsErrorsParamsResolved = "true"
)
// Defines values for GetV1AdminWorkerNotificationsSentParamsNotificationType.
const (
GetV1AdminWorkerNotificationsSentParamsNotificationTypeDailyReminder GetV1AdminWorkerNotificationsSentParamsNotificationType = "daily_reminder"
GetV1AdminWorkerNotificationsSentParamsNotificationTypeTestEmail GetV1AdminWorkerNotificationsSentParamsNotificationType = "test_email"
)
// Defines values for GetV1AdminWorkerNotificationsSentParamsStatus.
const (
GetV1AdminWorkerNotificationsSentParamsStatusBounced GetV1AdminWorkerNotificationsSentParamsStatus = "bounced"
GetV1AdminWorkerNotificationsSentParamsStatusFailed GetV1AdminWorkerNotificationsSentParamsStatus = "failed"
GetV1AdminWorkerNotificationsSentParamsStatusSent GetV1AdminWorkerNotificationsSentParamsStatus = "sent"
)
// Defines values for GetV1SnippetsParamsLevel.
// These constant names carry no type-name prefix.
// NOTE(review): values look like CEFR proficiency levels — confirm against
// the API spec.
const (
A1 GetV1SnippetsParamsLevel = "A1"
A2 GetV1SnippetsParamsLevel = "A2"
B1 GetV1SnippetsParamsLevel = "B1"
B2 GetV1SnippetsParamsLevel = "B2"
C1 GetV1SnippetsParamsLevel = "C1"
C2 GetV1SnippetsParamsLevel = "C2"
)
// AIConcurrencyStats defines model for AIConcurrencyStats.
//
// Every field is an optional pointer and is omitted from the JSON encoding
// when nil.
// NOTE(review): the schema carries no field descriptions; meanings are only
// implied by the field names — confirm against the OpenAPI spec.
type AIConcurrencyStats struct {
ActiveRequests *int `json:"active_requests,omitempty"`
MaxConcurrent *int `json:"max_concurrent,omitempty"`
MaxPerUser *int `json:"max_per_user,omitempty"`
QueuedRequests *int `json:"queued_requests,omitempty"`
TotalRequests *int `json:"total_requests,omitempty"`
// UserActiveCount maps a string key to an int count.
// NOTE(review): presumably keyed by user identifier — confirm.
UserActiveCount *map[string]int `json:"user_active_count,omitempty"`
}
// AIProviders defines model for AIProviders.
//
// Both top-level fields are optional. Providers is a list of anonymous
// structs, each of which may in turn carry a list of anonymous model entries
// (code and name).
type AIProviders struct {
Levels *[]string `json:"levels,omitempty"`
Providers *[]struct {
Code *string `json:"code,omitempty"`
Models *[]struct {
Code *string `json:"code,omitempty"`
Name *string `json:"name,omitempty"`
} `json:"models,omitempty"`
Name *string `json:"name,omitempty"`
Url *string `json:"url,omitempty"`
// UsageSupported Whether the provider supports usage tracking in streaming responses
UsageSupported *bool `json:"usage_supported,omitempty"`
} `json:"providers,omitempty"`
}
// APIKeyAvailabilityResponse defines model for APIKeyAvailabilityResponse.
// Its single field is a non-pointer bool without omitempty, so it is always
// present in the JSON encoding.
type APIKeyAvailabilityResponse struct {
// HasApiKey Whether the user has a saved API key for this provider
HasApiKey bool `json:"has_api_key"`
}
// APIKeySummary defines model for APIKeySummary.
//
// It exposes the key's name and prefix but has no field for the full key
// value. All fields are optional pointers.
type APIKeySummary struct {
// CreatedAt Creation timestamp
CreatedAt *time.Time `json:"created_at,omitempty"`
// Id Unique ID
Id *int `json:"id,omitempty"`
// KeyName Name of the key
KeyName *string `json:"key_name,omitempty"`
// KeyPrefix First characters for identification
KeyPrefix *string `json:"key_prefix,omitempty"`
// LastUsedAt Last time this key was used.
// The tag has no omitempty, so a nil value is encoded as JSON null rather
// than being dropped.
LastUsedAt *time.Time `json:"last_used_at"`
// PermissionLevel Permission level
PermissionLevel *APIKeySummaryPermissionLevel `json:"permission_level,omitempty"`
// UpdatedAt Last update timestamp
UpdatedAt *time.Time `json:"updated_at,omitempty"`
}
// APIKeySummaryPermissionLevel Permission level.
// String enum; see the APIKeySummaryPermissionLevel* constants ("full", "readonly").
type APIKeySummaryPermissionLevel string
// APIKeyTestResponse defines model for APIKeyTestResponse.
//
// All fields are optional pointers omitted from JSON when nil.
// NOTE(review): the schema carries no field descriptions; meanings are only
// implied by the field names — confirm against the OpenAPI spec.
type APIKeyTestResponse struct {
ApiKeyId *int `json:"api_key_id,omitempty"`
Method *string `json:"method,omitempty"`
Ok *bool `json:"ok,omitempty"`
PermissionLevel *APIKeyTestResponsePermissionLevel `json:"permission_level,omitempty"`
UserId *int `json:"user_id,omitempty"`
Username *string `json:"username,omitempty"`
}
// APIKeyTestResponsePermissionLevel defines model for APIKeyTestResponse.PermissionLevel.
// String enum; see the APIKeyTestResponsePermissionLevel* constants ("full", "readonly").
type APIKeyTestResponsePermissionLevel string
// APIKeysListResponse defines model for APIKeysListResponse.
// It pairs an optional list of APIKeySummary entries with an optional total
// count.
type APIKeysListResponse struct {
ApiKeys *[]APIKeySummary `json:"api_keys,omitempty"`
// Count Total number of keys
Count *int `json:"count,omitempty"`
}
// AggregatedVersion defines model for AggregatedVersion.
// It combines version information for the backend with a worker entry; both
// fields are required (no omitempty).
type AggregatedVersion struct {
Backend ServiceVersion `json:"backend"`
Worker AggregatedVersion_Worker `json:"worker"`
}
// AggregatedVersionWorker1 is an anonymous schema variant that carries only
// an error message for when the worker is unavailable.
// NOTE(review): presumably one of the alternatives held by
// AggregatedVersion_Worker's union — the accessor methods are not in this
// chunk; confirm.
type AggregatedVersionWorker1 struct {
// Error Error message when worker is unavailable
Error string `json:"error"`
}
// AggregatedVersion_Worker defines model for AggregatedVersion.Worker.
// The value is kept as raw, undecoded JSON in the unexported union field, so
// it cannot be set or read directly from outside this package.
type AggregatedVersion_Worker struct {
union json.RawMessage
}
// AnswerRequest defines model for AnswerRequest.
// QuestionId and UserAnswerIndex are required (always encoded);
// ResponseTimeMs is optional.
type AnswerRequest struct {
// QuestionId ID of the question being answered
QuestionId int64 `json:"question_id"`
// ResponseTimeMs Response time in milliseconds (0-5 minutes)
ResponseTimeMs *int32 `json:"response_time_ms,omitempty"`
// UserAnswerIndex Index of the user's selected answer in the original options array (0-based)
UserAnswerIndex int `json:"user_answer_index"`
}
// AnswerResponse defines model for AnswerResponse.
// All fields are optional pointers omitted from JSON when nil.
type AnswerResponse struct {
// CorrectAnswerIndex Index of the correct answer in the options array (0-based)
CorrectAnswerIndex *int `json:"correct_answer_index,omitempty"`
Explanation *string `json:"explanation,omitempty"`
IsCorrect *bool `json:"is_correct,omitempty"`
NextDifficulty *string `json:"next_difficulty,omitempty"`
// UserAnswer The answer selected by the user
UserAnswer *string `json:"user_answer,omitempty"`
// UserAnswerIndex Index of the user's selected answer in the original options array (0-based)
UserAnswerIndex *int `json:"user_answer_index,omitempty"`
}
// AuthStatusResponse defines model for AuthStatusResponse.
// Both fields are required: User is embedded by value, so a user object is
// always encoded regardless of the Authenticated flag.
type AuthStatusResponse struct {
// Authenticated Whether the user is currently authenticated
Authenticated bool `json:"authenticated"`
User User `json:"user"`
}
// ChatMessage defines model for ChatMessage.
//
// Content, ConversationId, CreatedAt, Id, Role and UpdatedAt are required
// (always encoded); Bookmarked, ConversationTitle and QuestionId are optional
// pointers.
type ChatMessage struct {
// Bookmarked Whether this message is bookmarked
Bookmarked *bool `json:"bookmarked,omitempty"`
// Content Message content.
// The object itself is required, but its Text may be nil.
Content struct {
// Text The actual message text
Text *string `json:"text,omitempty"`
} `json:"content"`
// ConversationId ID of the conversation this message belongs to
ConversationId openapi_types.UUID `json:"conversation_id"`
// ConversationTitle Title of the conversation (optional, included in search results)
ConversationTitle *string `json:"conversation_title,omitempty"`
// CreatedAt When the message was created
CreatedAt time.Time `json:"created_at"`
// Id Message UUID
Id openapi_types.UUID `json:"id"`
// QuestionId Optional question ID if this message relates to a specific question
QuestionId *int `json:"question_id,omitempty"`
// Role Role of the message sender
Role ChatMessageRole `json:"role"`
// UpdatedAt When the message was last updated
UpdatedAt time.Time `json:"updated_at"`
}
// ChatMessageRole Role of the message sender.
// String enum; see the ChatMessageRole* constants ("assistant", "user").
type ChatMessageRole string
// Conversation defines model for Conversation.
//
// CreatedAt, Id, Title, UpdatedAt and UserId are required (always encoded);
// MessageCount and Messages are optional pointers.
type Conversation struct {
// CreatedAt When the conversation was created
CreatedAt time.Time `json:"created_at"`
// Id Conversation UUID
Id openapi_types.UUID `json:"id"`
// MessageCount Total number of messages in this conversation
MessageCount *int `json:"message_count,omitempty"`
// Messages Array of messages in this conversation (optional, only included when requested)
Messages *[]ChatMessage `json:"messages,omitempty"`
// Title Conversation title
Title string `json:"title"`
// UpdatedAt When the conversation was last updated
UpdatedAt time.Time `json:"updated_at"`
// UserId ID of the user who owns this conversation
UserId int `json:"user_id"`
}
// CreateAPIKeyRequest defines model for CreateAPIKeyRequest.
// Both fields are required (non-pointer, no omitempty).
type CreateAPIKeyRequest struct {
// KeyName A descriptive name for the API key
KeyName string `json:"key_name"`
// PermissionLevel Permission level: 'readonly' for GET requests only, 'full' for all operations
PermissionLevel CreateAPIKeyRequestPermissionLevel `json:"permission_level"`
}
// CreateAPIKeyRequestPermissionLevel Permission level: 'readonly' for GET requests only, 'full' for all operations.
// String enum; see the CreateAPIKeyRequestPermissionLevel* constants.
type CreateAPIKeyRequestPermissionLevel string
// CreateAPIKeyResponse is the response returned after creating an API key.
// Every field is an optional pointer and is left out of the JSON when nil.
type CreateAPIKeyResponse struct {
// CreatedAt Creation timestamp
CreatedAt *time.Time `json:"created_at,omitempty"`
// Id Unique ID of the API key
Id *int `json:"id,omitempty"`
// Key Full API key - only shown once!
// Callers that need the key later must store it from this response.
Key *string `json:"key,omitempty"`
// KeyName Name of the API key
KeyName *string `json:"key_name,omitempty"`
// KeyPrefix First characters of key for identification
KeyPrefix *string `json:"key_prefix,omitempty"`
// Message Warning message
Message *string `json:"message,omitempty"`
// PermissionLevel Permission level
PermissionLevel *CreateAPIKeyResponsePermissionLevel `json:"permission_level,omitempty"`
}
// CreateAPIKeyResponsePermissionLevel is the permission level reported for a
// newly created API key. It is a plain string type, distinct from
// CreateAPIKeyRequestPermissionLevel, so values need an explicit conversion
// between the two.
type CreateAPIKeyResponsePermissionLevel string
// CreateConversationRequest is the request body for creating a conversation.
type CreateConversationRequest struct {
// Title Title for the conversation (always present in the encoded JSON)
Title string `json:"title"`
}
// CreateLinearIssueResponse is the response describing a Linear issue that
// was created. All three fields are always present in the encoded JSON.
type CreateLinearIssueResponse struct {
// IssueId The Linear issue ID
IssueId string `json:"issue_id"`
// IssueUrl URL to the created Linear issue
IssueUrl string `json:"issue_url"`
// Title The title of the created Linear issue
Title string `json:"title"`
}
// CreateMessageRequest is the request body for adding a message to a
// conversation.
type CreateMessageRequest struct {
// Content Message content.
// The "content" object is always encoded; only its "text" member is optional.
Content struct {
// Text The actual message text
Text *string `json:"text,omitempty"`
} `json:"content"`
// QuestionId Optional question ID if this message relates to a specific question
QuestionId *int `json:"question_id,omitempty"`
// Role Role of the message sender
Role CreateMessageRequestRole `json:"role"`
}
// CreateMessageRequestRole is the role of the sender in a CreateMessageRequest.
// It is a plain string type, distinct from ChatMessageRole; the valid values
// are not visible here.
type CreateMessageRequestRole string
// CreateSnippetRequest is the request body for saving a text snippet together
// with its translation.
//
// The pointer fields here have no omitempty, so a nil pointer is encoded as
// JSON null instead of being left out.
type CreateSnippetRequest struct {
// Context Optional user-provided context or notes about this snippet (nullable)
Context *string `json:"context"`
// OriginalText The original text/word to save
OriginalText string `json:"original_text"`
// QuestionId Optional ID of the question where this text was encountered. If provided, the snippet will inherit the question's difficulty level (A1, A2, B1, B2, C1, C2)
QuestionId *int64 `json:"question_id"`
// SectionId Optional ID of the story section where this text was encountered (nullable)
SectionId *int64 `json:"section_id"`
// SourceLanguage ISO language code of the source text
SourceLanguage string `json:"source_language"`
// StoryId Optional ID of the story where this text was encountered (nullable)
StoryId *int64 `json:"story_id"`
// TargetLanguage ISO language code of the target translation
TargetLanguage string `json:"target_language"`
// TranslatedText The translated text
TranslatedText string `json:"translated_text"`
}
// CreateStoryRequest is the request body for creating a story.
//
// Title is the only non-pointer field. The pointer fields without omitempty
// are nullable: a nil value is encoded as JSON null. SectionLengthOverride is
// left out entirely when nil.
type CreateStoryRequest struct {
AuthorStyle *string `json:"author_style"`
CharacterNames *string `json:"character_names"`
CustomInstructions *string `json:"custom_instructions"`
Genre *string `json:"genre"`
SectionLengthOverride *CreateStoryRequestSectionLengthOverride `json:"section_length_override,omitempty"`
Subject *string `json:"subject"`
TimePeriod *string `json:"time_period"`
Title string `json:"title"`
Tone *string `json:"tone"`
}
// CreateStoryRequestSectionLengthOverride is the type of
// CreateStoryRequest.SectionLengthOverride. It is a plain string type; the
// valid values are not visible here.
type CreateStoryRequestSectionLengthOverride string
// DailyProgress reports how many of a day's assigned questions have been
// completed. All fields are always present in the encoded JSON.
type DailyProgress struct {
// Completed Number of completed questions
Completed int `json:"completed"`
// Date Date for the progress report (YYYY-MM-DD)
Date openapi_types.Date `json:"date"`
// Total Total number of questions assigned for the date
Total int `json:"total"`
}
// DailyQuestionHistory is one history entry for a daily question assignment.
//
// AssignmentDate here is an RFC3339 string, whereas
// DailyQuestionWithDetails.AssignmentDate is a date-only openapi_types.Date.
type DailyQuestionHistory struct {
// AssignmentDate RFC3339 timestamp of when the question was assigned in the user's timezone (includes offset)
AssignmentDate string `json:"assignment_date"`
// IsCompleted Whether the question was completed on this date
IsCompleted bool `json:"is_completed"`
// IsCorrect Whether the user's answer was correct (null if not attempted).
// A nil pointer is encoded as JSON null (no omitempty).
IsCorrect *bool `json:"is_correct"`
// SubmittedAt When the user submitted their answer.
// Nullable string; the timestamp format is not specified here.
SubmittedAt *string `json:"submitted_at"`
}
// DailyQuestionWithDetails is a daily question assignment together with the
// full Question and optional per-user answer statistics.
//
// Pointer fields without omitempty (CompletedAt, SubmittedAt,
// UserAnswerIndex) are nullable and encoded as JSON null when nil; the
// User*Count and UserTotalResponses fields are left out when nil.
type DailyQuestionWithDetails struct {
// AssignmentDate Date-only assignment (YYYY-MM-DD) representing the logical calendar day the question was assigned (no timezone offset)
AssignmentDate openapi_types.Date `json:"assignment_date"`
// CompletedAt When the question was completed (if completed)
CompletedAt *string `json:"completed_at"`
// CreatedAt When the assignment was created
CreatedAt string `json:"created_at"`
// Id Daily question assignment ID
Id int64 `json:"id"`
// IsCompleted Whether the question has been completed
IsCompleted bool `json:"is_completed"`
// Question is embedded by value and always encoded.
Question Question `json:"question"`
// QuestionId Question ID
QuestionId int64 `json:"question_id"`
// SubmittedAt When the user submitted their answer
SubmittedAt *string `json:"submitted_at"`
// UserAnswerIndex The index of the answer option the user selected (0-based)
UserAnswerIndex *int `json:"user_answer_index"`
// UserCorrectCount Number of times this user answered this question correctly
UserCorrectCount *int64 `json:"user_correct_count,omitempty"`
// UserId User ID
UserId int64 `json:"user_id"`
// UserIncorrectCount Number of times this user answered this question incorrectly
UserIncorrectCount *int64 `json:"user_incorrect_count,omitempty"`
// UserShownCount Number of times this question was shown to this user in Daily view
UserShownCount *int64 `json:"user_shown_count,omitempty"`
// UserTotalResponses Number of times this user answered this question
UserTotalResponses *int64 `json:"user_total_responses,omitempty"`
}
// DashboardResponse aggregates the data returned for the dashboard: AI
// concurrency stats, question stats, users, and worker connection/health
// information. Every field is an optional pointer left out of the JSON when nil.
type DashboardResponse struct {
AiConcurrencyStats *AIConcurrencyStats `json:"ai_concurrency_stats,omitempty"`
QuestionStats *QuestionStats `json:"question_stats,omitempty"`
Users *[]DashboardUser `json:"users,omitempty"`
WorkerBaseUrl *string `json:"worker_base_url,omitempty"`
WorkerHealth *WorkerHealth `json:"worker_health,omitempty"`
// WorkerPort is carried as a string, not a number.
WorkerPort *string `json:"worker_port,omitempty"`
}
// DashboardUser groups one user's profile with that user's progress and
// question statistics, as listed in DashboardResponse.Users. Every field is
// optional and left out of the JSON when nil.
type DashboardUser struct {
Progress *UserProgress `json:"progress,omitempty"`
QuestionStats *UserQuestionStats `json:"question_stats,omitempty"`
User *UserProfile `json:"user,omitempty"`
}
// DeleteAPIKeyResponse is the response returned after deleting an API key.
// Both fields are optional; unlike SuccessResponse, Success here is a pointer
// and is left out of the JSON when nil.
type DeleteAPIKeyResponse struct {
Message *string `json:"message,omitempty"`
Success *bool `json:"success,omitempty"`
}
// EmptyRequest Empty request body for endpoints that don't require request data.
// It is a type alias (not a new type), so it is interchangeable with
// map[string]interface{} without conversion.
type EmptyRequest = map[string]interface{}
// ErrorResponse is the JSON error payload. Every field is an optional
// pointer and is left out of the JSON when nil.
type ErrorResponse struct {
// Code Error code identifying the type of error
Code *string `json:"code,omitempty"`
// Details Additional error details
Details *string `json:"details,omitempty"`
// Error Error message (for backward compatibility)
Error *string `json:"error,omitempty"`
// Message Human-readable error message
Message *string `json:"message,omitempty"`
// Retryable Whether the operation can be retried
Retryable *bool `json:"retryable,omitempty"`
// Severity Severity level of the error
Severity *ErrorResponseSeverity `json:"severity,omitempty"`
}
// ErrorResponseSeverity is the severity level carried in
// ErrorResponse.Severity. It is a plain string type; the valid levels are not
// visible here.
type ErrorResponseSeverity string
// FeedbackListResponse is one page of feedback reports plus paging totals.
type FeedbackListResponse struct {
// Items List of feedback reports.
// A nil slice is encoded as JSON null, not [] (no omitempty, non-pointer slice).
Items []FeedbackReport `json:"items"`
// Page Current page number
Page int `json:"page"`
// PageSize Number of items per page
PageSize int `json:"page_size"`
// Total Total number of feedback reports matching filters
Total int `json:"total"`
}
// FeedbackReport is a stored feedback or issue report, including its triage
// state.
//
// Pointer fields without omitempty (AdminNotes, AssignedToUserId, ResolvedAt,
// ResolvedByUserId, ScreenshotData, ScreenshotUrl) are nullable and encoded
// as JSON null when nil. ContextData is left out when nil.
type FeedbackReport struct {
// AdminNotes Notes from admin
AdminNotes *string `json:"admin_notes"`
// AssignedToUserId User ID assigned to handle this feedback
AssignedToUserId *int64 `json:"assigned_to_user_id"`
// ContextData Context metadata as JSON object
ContextData *map[string]interface{} `json:"context_data,omitempty"`
// CreatedAt When the feedback was created
CreatedAt time.Time `json:"created_at"`
// FeedbackText Feedback or issue description
FeedbackText string `json:"feedback_text"`
// FeedbackType Type of feedback
FeedbackType FeedbackReportFeedbackType `json:"feedback_type"`
// Id Feedback report ID
Id int64 `json:"id"`
// ResolvedAt When the feedback was resolved
ResolvedAt *time.Time `json:"resolved_at"`
// ResolvedByUserId User ID who resolved the feedback
ResolvedByUserId *int64 `json:"resolved_by_user_id"`
// ScreenshotData Base64 encoded screenshot.
// Held as a string here; FeedbackSubmissionRequest.ScreenshotData is *[]byte.
ScreenshotData *string `json:"screenshot_data"`
// ScreenshotUrl URL to stored screenshot file
ScreenshotUrl *string `json:"screenshot_url"`
// Status Current status of the feedback
Status FeedbackReportStatus `json:"status"`
// UpdatedAt When the feedback was last updated
UpdatedAt time.Time `json:"updated_at"`
// UserId User ID who submitted the feedback
UserId int64 `json:"user_id"`
}
// FeedbackReportFeedbackType is the category of a FeedbackReport. It is a
// plain string type; the valid categories are not visible here.
type FeedbackReportFeedbackType string
// FeedbackReportStatus is the current triage status of a FeedbackReport. It
// is a plain string type; the valid statuses are not visible here.
type FeedbackReportStatus string
// FeedbackSubmissionRequest is the request body for submitting feedback.
// FeedbackText is the only field that is always encoded; the others are
// optional and left out of the JSON when nil.
type FeedbackSubmissionRequest struct {
// ContextData Context metadata as JSON object
ContextData *map[string]interface{} `json:"context_data,omitempty"`
// FeedbackText Feedback or issue description
FeedbackText string `json:"feedback_text"`
// FeedbackType Type of feedback
FeedbackType *FeedbackSubmissionRequestFeedbackType `json:"feedback_type,omitempty"`
// ScreenshotData Base64 encoded screenshot (optional).
// Because the Go type is []byte, encoding/json does the base64
// encoding/decoding itself; the slice holds the raw decoded bytes.
ScreenshotData *[]byte `json:"screenshot_data,omitempty"`
}
// FeedbackSubmissionRequestFeedbackType is the category supplied in a
// FeedbackSubmissionRequest. It is a plain string type, distinct from
// FeedbackReportFeedbackType; the valid categories are not visible here.
type FeedbackSubmissionRequestFeedbackType string
// FeedbackUpdateRequest is the request body for updating a feedback report.
// Every field is an optional pointer left out of the JSON when nil.
// NOTE(review): this looks like partial-update semantics (nil = leave
// unchanged); confirm against the handler.
type FeedbackUpdateRequest struct {
// AdminNotes Admin notes about this feedback
AdminNotes *string `json:"admin_notes,omitempty"`
// AssignedToUserId User ID to assign this feedback to
AssignedToUserId *int64 `json:"assigned_to_user_id,omitempty"`
// ResolvedAt When the feedback was resolved (use current time if status is resolved)
ResolvedAt *time.Time `json:"resolved_at,omitempty"`
// ResolvedByUserId User ID who resolved the feedback
ResolvedByUserId *int64 `json:"resolved_by_user_id,omitempty"`
// Status New status for the feedback
Status *FeedbackUpdateRequestStatus `json:"status,omitempty"`
}
// FeedbackUpdateRequestStatus is the new status requested in a
// FeedbackUpdateRequest. It is a plain string type, distinct from
// FeedbackReportStatus; the valid statuses are not visible here.
type FeedbackUpdateRequestStatus string
// ForceSendNotificationResponse is the response for a forced notification
// send. It carries a message plus optional summaries of the notification and
// of the target user. All members, including the nested objects, are left out
// of the JSON when nil.
type ForceSendNotificationResponse struct {
Message *string `json:"message,omitempty"`
// Notification summarizes the notification (status, subject, type).
Notification *struct {
Status *string `json:"status,omitempty"`
Subject *string `json:"subject,omitempty"`
Type *string `json:"type,omitempty"`
} `json:"notification,omitempty"`
// User summarizes the user the notification relates to.
User *struct {
Email *string `json:"email,omitempty"`
Id *int64 `json:"id,omitempty"`
Username *string `json:"username,omitempty"`
} `json:"user,omitempty"`
}
// GeneratingResponse is a response payload carrying a status and message
// together with the user's AI settings. Every field is optional and left out
// of the JSON when nil.
type GeneratingResponse struct {
// AiModel User's preferred AI model
AiModel *string `json:"ai_model,omitempty"`
// ApiKey User's API key for the selected provider (write-only).
// NOTE(review): the field is described as write-only yet sits on a
// response type; confirm it is never populated in responses.
ApiKey *string `json:"api_key,omitempty"`
Message *string `json:"message,omitempty"`
Status *string `json:"status,omitempty"`
}
// GenerationFocus describes the current state of question generation:
// model in use, rate, and time of the last generation. Every field is
// optional and left out of the JSON when nil.
type GenerationFocus struct {
// CurrentGenerationModel The AI model currently being used for generation
CurrentGenerationModel *string `json:"current_generation_model,omitempty"`
// GenerationRate Average number of questions generated per minute
GenerationRate *float32 `json:"generation_rate,omitempty"`
// LastGenerationTime Timestamp of the last time a question was generated.
// Carried as a string; the timestamp format is not specified here.
LastGenerationTime *string `json:"last_generation_time,omitempty"`
}
// GenerationIntelligence carries gap-analysis entries and generation
// suggestions as lists of free-form JSON objects; their schema is not defined
// by this type.
//
// NOTE: the JSON keys are camelCase ("gapAnalysis", "generationSuggestions"),
// unlike the snake_case keys of most neighbouring models.
type GenerationIntelligence struct {
GapAnalysis *[]map[string]interface{} `json:"gapAnalysis,omitempty"`
GenerationSuggestions *[]map[string]interface{} `json:"generationSuggestions,omitempty"`
}
// GoogleOAuthLoginResponse is the response that starts a Google OAuth login;
// it carries the URL the client should be redirected to.
type GoogleOAuthLoginResponse struct {
// AuthUrl The Google OAuth authorization URL to redirect the user to
AuthUrl string `json:"auth_url"`
}
// Language Learning language (dynamic). Allowed values come from config.yaml language_levels keys.
// It is a type alias for string, so no conversion is needed and the type
// itself performs no validation.
type Language = string
// LanguageInfo describes one available learning language and its optional
// text-to-speech defaults. Code and Name are always encoded; the TTS fields
// are left out of the JSON when nil.
type LanguageInfo struct {
// Code ISO language code
Code string `json:"code"`
// Name Human-readable language name
Name string `json:"name"`
// TtsLocale TTS locale code for this language
TtsLocale *string `json:"tts_locale,omitempty"`
// TtsVoice Default TTS voice for this language
TtsVoice *string `json:"tts_voice,omitempty"`
}
// LanguagesResponse Array of available learning languages with codes and names.
// It is a type alias for []LanguageInfo, so the JSON body is a bare array
// rather than an object wrapping one.
type LanguagesResponse = []LanguageInfo
// Level Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1-C2, JLPT N5-N1, HSK1-HSK6).
// It is a type alias for string; the type itself performs no validation.
type Level = string
// LevelsResponse lists the available proficiency levels together with a
// short label for each. Both fields are always encoded; a nil map or slice
// is encoded as JSON null.
type LevelsResponse struct {
// LevelDescriptions Mapping from level code to short label (e.g. Beginner, Intermediate)
LevelDescriptions map[string]string `json:"level_descriptions"`
// Levels Array of available language proficiency levels
Levels []string `json:"levels"`
}
// LoginRequest is the request body for username/password login.
// The length and character constraints described on the fields are not
// enforced by this type; validation has to happen elsewhere.
type LoginRequest struct {
// Password Password (minimum 8 characters)
Password string `json:"password"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username string `json:"username"`
}
// LoginResponse is the response to a login request. Every field is optional
// and left out of the JSON when nil.
type LoginResponse struct {
Message *string `json:"message,omitempty"`
// RedirectUri Redirect URI for OAuth flows (optional)
RedirectUri *string `json:"redirect_uri,omitempty"`
Success *bool `json:"success,omitempty"`
User *User `json:"user,omitempty"`
}
// MarkQuestionKnownRequest is the request body for marking a question as
// known, optionally with a confidence level.
type MarkQuestionKnownRequest struct {
// ConfidenceLevel User's confidence level (1-5, optional).
// The 1-5 range is not enforced by this type.
ConfidenceLevel *int `json:"confidence_level,omitempty"`
}
// NotificationError records a failed notification and how (or whether) it
// was resolved.
//
// EmailAddress, ResolutionNotes, ResolvedAt and UserId have no omitempty and
// are encoded as JSON null when nil; the remaining fields are left out when
// nil.
type NotificationError struct {
// EmailAddress Email address that was being used
EmailAddress *string `json:"email_address"`
// ErrorMessage Detailed error message
ErrorMessage *string `json:"error_message,omitempty"`
// ErrorType Type of error that occurred
ErrorType *NotificationErrorErrorType `json:"error_type,omitempty"`
Id *int64 `json:"id,omitempty"`
// NotificationType Type of notification that failed
NotificationType *NotificationErrorNotificationType `json:"notification_type,omitempty"`
// OccurredAt When the error occurred
OccurredAt *string `json:"occurred_at,omitempty"`
// ResolutionNotes Notes about how the error was resolved
ResolutionNotes *string `json:"resolution_notes"`
// ResolvedAt When the error was resolved
ResolvedAt *string `json:"resolved_at"`
UserId *int64 `json:"user_id"`
// Username Username of the user (if available)
Username *string `json:"username,omitempty"`
}
// NotificationErrorErrorType is the kind of error recorded in a
// NotificationError. It is a plain string type; the valid values are not
// visible here.
type NotificationErrorErrorType string
// NotificationErrorNotificationType is the kind of notification that failed,
// as recorded in a NotificationError. It is a plain string type; the valid
// values are not visible here.
type NotificationErrorNotificationType string
// NotificationErrorStats holds aggregate counts of notification errors.
// Every field is optional and left out of the JSON when nil. The breakdown
// maps go from a category name to a count.
type NotificationErrorStats struct {
// ErrorsByNotificationType Breakdown of errors by notification type
ErrorsByNotificationType *map[string]int `json:"errors_by_notification_type,omitempty"`
// ErrorsByType Breakdown of errors by type
ErrorsByType *map[string]int `json:"errors_by_type,omitempty"`
// TotalErrors Total number of errors
TotalErrors *int `json:"total_errors,omitempty"`
// UnresolvedErrors Number of unresolved errors
UnresolvedErrors *int `json:"unresolved_errors,omitempty"`
}
// NotificationStats holds aggregate counts of sent and failed notifications.
// Every field is optional and left out of the JSON when nil.
type NotificationStats struct {
// NotificationsByType Breakdown of notifications by type
NotificationsByType *map[string]int `json:"notifications_by_type,omitempty"`
// SentThisWeek Number of notifications sent this week
SentThisWeek *int `json:"sent_this_week,omitempty"`
// SentToday Number of notifications sent today
SentToday *int `json:"sent_today,omitempty"`
// SuccessRate Success rate as a percentage (0-1).
// NOTE(review): the stated 0-1 range suggests a fraction rather than a
// 0-100 percentage; confirm against the producer before formatting.
SuccessRate *float32 `json:"success_rate,omitempty"`
// TotalFailed Total number of notifications that failed
TotalFailed *int `json:"total_failed,omitempty"`
// TotalSent Total number of notifications sent
TotalSent *int `json:"total_sent,omitempty"`
}
// PaginationInfo is the paging metadata that accompanies a paginated list.
// All fields are plain ints and always encoded, including zero values.
type PaginationInfo struct {
// Page Current page number
Page int `json:"page"`
// PageSize Number of items per page
PageSize int `json:"page_size"`
// Total Total number of items
Total int `json:"total"`
// TotalPages Total number of pages
TotalPages int `json:"total_pages"`
}
// PasswordResetRequest is the request body for setting a new password.
type PasswordResetRequest struct {
// NewPassword New password (minimum 8 characters).
// The minimum length is not enforced by this type.
NewPassword string `json:"new_password"`
}
// PerformanceMetrics holds attempt counts and response-time figures.
// Every field is optional and left out of the JSON when nil.
type PerformanceMetrics struct {
// AverageResponseTimeMs is presumably in milliseconds, going by the "Ms"
// suffix; confirm against the producer.
AverageResponseTimeMs *float32 `json:"average_response_time_ms,omitempty"`
CorrectAttempts *int `json:"correct_attempts,omitempty"`
// LastUpdated is carried as a string; the timestamp format is not
// specified here.
LastUpdated *string `json:"last_updated,omitempty"`
TotalAttempts *int `json:"total_attempts,omitempty"`
}
// PriorityInsights holds counts of questions per priority bucket and the
// size of the processing queue. Every field is optional and left out of the
// JSON when nil.
type PriorityInsights struct {
// HighPriorityQuestions Number of high-priority questions
HighPriorityQuestions *int `json:"high_priority_questions,omitempty"`
// LowPriorityQuestions Number of low-priority questions
LowPriorityQuestions *int `json:"low_priority_questions,omitempty"`
// MediumPriorityQuestions Number of medium-priority questions
MediumPriorityQuestions *int `json:"medium_priority_questions,omitempty"`
// TotalQuestionsInQueue Total number of questions waiting to be processed
TotalQuestionsInQueue *int `json:"total_questions_in_queue,omitempty"`
}
// Question is the JSON model of a quiz question: its content, classification
// metadata and aggregate answer statistics.
//
// Every field is an optional pointer and is left out of the JSON when nil,
// so a zero-value Question encodes as {}.
type Question struct {
// ConfidenceLevel Confidence level when question was marked as known (1-5)
ConfidenceLevel *int `json:"confidence_level,omitempty"`
// Content All question types now use multiple choice format with 4 options
Content *QuestionContent `json:"content,omitempty"`
// CorrectAnswer Index of the correct answer in the options array (0-based)
CorrectAnswer *int `json:"correct_answer,omitempty"`
// CorrectCount Number of times this question was answered correctly
CorrectCount *int `json:"correct_count,omitempty"`
// CreatedAt is carried as a string; the timestamp format is not specified here.
CreatedAt *string `json:"created_at,omitempty"`
// DifficultyModifier Difficulty modifier for the question (e.g., basic, intermediate)
DifficultyModifier *string `json:"difficulty_modifier,omitempty"`
DifficultyScore *float32 `json:"difficulty_score,omitempty"`
Explanation *string `json:"explanation,omitempty"`
// GrammarFocus Grammar focus area for the question (e.g., present_perfect, conditionals)
GrammarFocus *string `json:"grammar_focus,omitempty"`
Id *int64 `json:"id,omitempty"`
// IncorrectCount Number of times this question was answered incorrectly
IncorrectCount *int `json:"incorrect_count,omitempty"`
// Language Learning language (dynamic). Allowed values come from config.yaml language_levels keys.
Language *Language `json:"language,omitempty"`
// Level Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1-C2, JLPT N5-N1, HSK1-HSK6).
Level *Level `json:"level,omitempty"`
// Reporters Comma-separated list of usernames who reported this question
Reporters *string `json:"reporters,omitempty"`
// Scenario Scenario context for the question (e.g., at_the_airport, in_a_restaurant)
Scenario *string `json:"scenario,omitempty"`
Status *QuestionStatus `json:"status,omitempty"`
// StyleModifier Style modifier for the question (e.g., conversational, formal)
StyleModifier *string `json:"style_modifier,omitempty"`
// TimeContext Time context for the question (e.g., morning_routine, workday)
TimeContext *string `json:"time_context,omitempty"`
// TopicCategory General topic category for question context (e.g., daily_life, travel, work)
TopicCategory *string `json:"topic_category,omitempty"`
// TotalResponses Total number of responses to this question (used for 'Shown' in the UI)
TotalResponses *int `json:"total_responses,omitempty"`
Type *QuestionType `json:"type,omitempty"`
// UserCount Number of users assigned to this question
UserCount *int `json:"user_count,omitempty"`
// VocabularyDomain Vocabulary domain for the question (e.g., food_and_dining, transportation)
VocabularyDomain *string `json:"vocabulary_domain,omitempty"`
}
// QuestionContent All question types now use multiple choice format with 4 options.
// Options and Question are always encoded; Hint, Passage and Sentence are
// left out of the JSON when nil. The "4 options" count is not enforced by
// this type.
type QuestionContent struct {
// Hint Optional hint for fill-in-blank questions
Hint *string `json:"hint,omitempty"`
// Options holds the answer choices. A nil slice is encoded as JSON null.
Options []string `json:"options"`
// Passage Only present for reading comprehension questions
Passage *string `json:"passage,omitempty"`
Question string `json:"question"`
// Sentence Only present for vocabulary questions (context sentence)
Sentence *string `json:"sentence,omitempty"`
}
// QuestionStats holds aggregate question counts, overall and broken down by
// language, level and type. Every field is optional and left out of the JSON
// when nil. The breakdown maps go from a category name to a count.
type QuestionStats struct {
// QuestionsByLanguage Breakdown of questions by language
QuestionsByLanguage *map[string]int `json:"questions_by_language,omitempty"`
// QuestionsByLevel Breakdown of questions by level
QuestionsByLevel *map[string]int `json:"questions_by_level,omitempty"`
// QuestionsByType Breakdown of questions by type
QuestionsByType *map[string]int `json:"questions_by_type,omitempty"`
// TotalQuestions Total number of questions
TotalQuestions *int `json:"total_questions,omitempty"`
// TotalResponses Total number of responses
TotalResponses *int `json:"total_responses,omitempty"`
}
// QuestionStatus is the type of Question.Status. It is a plain string type;
// the valid statuses are not visible here.
type QuestionStatus string
// QuestionType is the type of Question.Type. It is a plain string type; the
// valid question types are not visible here.
type QuestionType string
// QuizChatRequest is the request body for a chat turn with the AI tutor
// about a specific question.
type QuizChatRequest struct {
// AnswerContext is optional and left out of the JSON when nil.
AnswerContext *AnswerResponse `json:"answer_context,omitempty"`
// ConversationHistory Previous messages in the conversation
ConversationHistory *[]ChatMessage `json:"conversation_history,omitempty"`
// Question is embedded by value and always encoded.
Question Question `json:"question"`
// UserMessage The user's message to the AI tutor.
UserMessage string `json:"user_message"`
}
// ReportQuestionRequest is the request body for reporting a question.
type ReportQuestionRequest struct {
// ReportReason Optional explanation for why the question is being reported.
// Left out of the JSON when nil.
ReportReason *string `json:"report_reason,omitempty"`
}
// Role is the JSON model of a user role. All fields are non-pointer values
// and always encoded. The timestamps are plain strings; their format is not
// specified here.
type Role struct {
// CreatedAt When the role was created
CreatedAt string `json:"created_at"`
// Description Role description
Description string `json:"description"`
// Id Role ID
Id int64 `json:"id"`
// Name Role name (e.g., "user", "admin")
Name string `json:"name"`
// UpdatedAt When the role was last updated
UpdatedAt string `json:"updated_at"`
}
// SentNotification is the record of one notification send attempt.
//
// ErrorMessage has no omitempty and is encoded as JSON null when nil; every
// other field is left out of the JSON when nil.
type SentNotification struct {
// EmailAddress Email address the notification was sent to
EmailAddress *string `json:"email_address,omitempty"`
// ErrorMessage Error message if the notification failed
ErrorMessage *string `json:"error_message"`
Id *int64 `json:"id,omitempty"`
// NotificationType Type of notification
NotificationType *SentNotificationNotificationType `json:"notification_type,omitempty"`
// RetryCount Number of times the notification was retried
RetryCount *int `json:"retry_count,omitempty"`
// SentAt When the notification was sent
SentAt *string `json:"sent_at,omitempty"`
// Status Status of the notification
Status *SentNotificationStatus `json:"status,omitempty"`
// Subject Subject line of the email
Subject *string `json:"subject,omitempty"`
// TemplateName Template used for the notification
TemplateName *string `json:"template_name,omitempty"`
UserId *int64 `json:"user_id,omitempty"`
// Username Username of the user
Username *string `json:"username,omitempty"`
}
// SentNotificationNotificationType is the kind of notification recorded in a
// SentNotification. It is a plain string type; the valid values are not
// visible here.
type SentNotificationNotificationType string
// SentNotificationStatus is the delivery status recorded in a
// SentNotification. It is a plain string type; the valid statuses are not
// visible here.
type SentNotificationStatus string
// ServiceUsageStatsResponse reports usage figures for one named service as a
// list of per-month, per-usage-type entries.
type ServiceUsageStatsResponse struct {
// Data holds the usage entries. It is always encoded; a nil slice becomes
// JSON null. Each entry's members are optional and left out when nil.
Data []struct {
// CharactersUsed Number of characters processed
CharactersUsed *int `json:"characters_used,omitempty"`
// Month First day of the month (YYYY-MM).
// NOTE(review): the description says "first day" but the stated format
// has no day component; confirm the actual format.
Month *string `json:"month,omitempty"`
// Quota Monthly quota for this service
Quota *int `json:"quota,omitempty"`
// RequestsMade Number of requests made
RequestsMade *int `json:"requests_made,omitempty"`
// UsageType Type of usage (e.g., "translation")
UsageType *string `json:"usage_type,omitempty"`
} `json:"data"`
// Service Name of the service
Service string `json:"service"`
}
// ServiceVersion is the build and version information of one service.
// All fields are always encoded.
type ServiceVersion struct {
// BuildTime Build timestamp (ISO8601).
// NOTE: the JSON key is camelCase ("buildTime"), unlike the snake_case keys
// of most neighbouring models.
BuildTime string `json:"buildTime"`
// Commit Git commit hash
Commit string `json:"commit"`
// Service Service name (e.g., 'backend', 'worker')
Service string `json:"service"`
// Version Version string (e.g., git tag or 'dev')
Version string `json:"version"`
}
// SignupStatusResponse reports whether new user signups are currently
// disabled.
type SignupStatusResponse struct {
// SignupsDisabled Whether user signups are currently disabled.
// Always encoded, including when false.
SignupsDisabled bool `json:"signups_disabled"`
}
// Snippet is the JSON model of a saved text snippet and its translation.
//
// Context, DifficultyLevel, QuestionId, SectionId and StoryId have no
// omitempty and are encoded as JSON null when nil; the remaining fields are
// left out when nil.
type Snippet struct {
Context *string `json:"context"`
CreatedAt *time.Time `json:"created_at,omitempty"`
// DifficultyLevel CEFR level (A1, A2, B1, B2, C1, C2)
DifficultyLevel *string `json:"difficulty_level"`
Id *int64 `json:"id,omitempty"`
OriginalText *string `json:"original_text,omitempty"`
QuestionId *int64 `json:"question_id"`
// SectionId ID of the story section where this snippet was created
SectionId *int64 `json:"section_id"`
SourceLanguage *string `json:"source_language,omitempty"`
// StoryId ID of the story where this snippet was created
StoryId *int64 `json:"story_id"`
TargetLanguage *string `json:"target_language,omitempty"`
TranslatedText *string `json:"translated_text,omitempty"`
UpdatedAt *time.Time `json:"updated_at,omitempty"`
UserId *int64 `json:"user_id,omitempty"`
}
// SnippetList is one page of snippets together with the paging parameters
// and the search query that produced it.
type SnippetList struct {
// Limit Number of snippets returned
Limit *int `json:"limit,omitempty"`
// Offset Number of snippets skipped
Offset *int `json:"offset,omitempty"`
// Query The search query that was used (if any).
// No omitempty: encoded as JSON null when nil.
Query *string `json:"query"`
Snippets *[]Snippet `json:"snippets,omitempty"`
// Total Total number of snippets matching the query
Total *int `json:"total,omitempty"`
}
// Story is the JSON model of a story and its generation settings.
//
// Pointer fields without omitempty (AuthorStyle, CharacterNames,
// CustomInstructions, Genre, LastSectionGeneratedAt, Subject, TimePeriod,
// Tone) are encoded as JSON null when nil; the others are left out when nil.
// StoryWithSections repeats these fields and adds the sections.
type Story struct {
AuthorStyle *string `json:"author_style"`
// AutoGenerationPaused When true, the worker will skip automatic section generation for this story
AutoGenerationPaused *bool `json:"auto_generation_paused,omitempty"`
CharacterNames *string `json:"character_names"`
CreatedAt *time.Time `json:"created_at,omitempty"`
CustomInstructions *string `json:"custom_instructions"`
ExtraGenerationsToday *int `json:"extra_generations_today,omitempty"`
Genre *string `json:"genre"`
Id *int64 `json:"id,omitempty"`
Language *string `json:"language,omitempty"`
LastSectionGeneratedAt *time.Time `json:"last_section_generated_at"`
SectionLengthOverride *StorySectionLengthOverride `json:"section_length_override,omitempty"`
Status *StoryStatus `json:"status,omitempty"`
Subject *string `json:"subject"`
TimePeriod *string `json:"time_period"`
Title *string `json:"title,omitempty"`
Tone *string `json:"tone"`
UpdatedAt *time.Time `json:"updated_at,omitempty"`
UserId *int64 `json:"user_id,omitempty"`
}
// StorySectionLengthOverride is the type of Story.SectionLengthOverride.
// Despite the name it is not related to the StorySection struct. It is a
// plain string type; the valid values are not visible here.
type StorySectionLengthOverride string
// StoryStatus is the type of Story.Status. It is a plain string type; the
// valid statuses are not visible here.
type StoryStatus string
// StorySection is the JSON model of one generated section of a story.
// Every field is optional and left out of the JSON when nil.
// StorySectionWithQuestions repeats these fields and adds the questions.
type StorySection struct {
Content *string `json:"content,omitempty"`
// GeneratedAt is a full timestamp, whereas GenerationDate is date-only.
GeneratedAt *time.Time `json:"generated_at,omitempty"`
GenerationDate *openapi_types.Date `json:"generation_date,omitempty"`
Id *int64 `json:"id,omitempty"`
LanguageLevel *string `json:"language_level,omitempty"`
SectionNumber *int `json:"section_number,omitempty"`
StoryId *int64 `json:"story_id,omitempty"`
WordCount *int `json:"word_count,omitempty"`
}
// StorySectionQuestion is the JSON model of a multiple-choice question
// attached to a story section.
//
// Explanation has no omitempty and is encoded as JSON null when nil; the
// other fields are left out when nil.
type StorySectionQuestion struct {
// CorrectAnswerIndex is presumably an index into Options; whether it is
// 0-based is not stated here. Confirm before use.
CorrectAnswerIndex *int `json:"correct_answer_index,omitempty"`
CreatedAt *time.Time `json:"created_at,omitempty"`
Explanation *string `json:"explanation"`
Id *int64 `json:"id,omitempty"`
Options *[]string `json:"options,omitempty"`
QuestionText *string `json:"question_text,omitempty"`
SectionId *int64 `json:"section_id,omitempty"`
}
// StorySectionWithQuestions is a story section together with its questions.
// It repeats the StorySection fields (declared again, not embedded) and adds
// Questions. Every field is optional and left out of the JSON when nil.
type StorySectionWithQuestions struct {
Content *string `json:"content,omitempty"`
GeneratedAt *time.Time `json:"generated_at,omitempty"`
GenerationDate *openapi_types.Date `json:"generation_date,omitempty"`
Id *int64 `json:"id,omitempty"`
LanguageLevel *string `json:"language_level,omitempty"`
Questions *[]StorySectionQuestion `json:"questions,omitempty"`
SectionNumber *int `json:"section_number,omitempty"`
StoryId *int64 `json:"story_id,omitempty"`
WordCount *int `json:"word_count,omitempty"`
}
// StoryWithSections is a story together with its sections. It repeats the
// Story fields (declared again, not embedded) and adds Sections.
//
// Its SectionLengthOverride and Status use their own named string types
// (StoryWithSectionsSectionLengthOverride, StoryWithSectionsStatus), so
// values copied from a Story need an explicit conversion. Pointer fields
// without omitempty are encoded as JSON null when nil.
type StoryWithSections struct {
AuthorStyle *string `json:"author_style"`
// AutoGenerationPaused When true, the worker will skip automatic section generation for this story
AutoGenerationPaused *bool `json:"auto_generation_paused,omitempty"`
CharacterNames *string `json:"character_names"`
CreatedAt *time.Time `json:"created_at,omitempty"`
CustomInstructions *string `json:"custom_instructions"`
ExtraGenerationsToday *int `json:"extra_generations_today,omitempty"`
Genre *string `json:"genre"`
Id *int64 `json:"id,omitempty"`
Language *string `json:"language,omitempty"`
LastSectionGeneratedAt *time.Time `json:"last_section_generated_at"`
SectionLengthOverride *StoryWithSectionsSectionLengthOverride `json:"section_length_override,omitempty"`
Sections *[]StorySection `json:"sections,omitempty"`
Status *StoryWithSectionsStatus `json:"status,omitempty"`
Subject *string `json:"subject"`
TimePeriod *string `json:"time_period"`
Title *string `json:"title,omitempty"`
Tone *string `json:"tone"`
UpdatedAt *time.Time `json:"updated_at,omitempty"`
UserId *int64 `json:"user_id,omitempty"`
}
// StoryWithSectionsSectionLengthOverride is the type of
// StoryWithSections.SectionLengthOverride. It is a plain string type,
// distinct from StorySectionLengthOverride; the valid values are not visible
// here.
type StoryWithSectionsSectionLengthOverride string
// StoryWithSectionsStatus is the type of StoryWithSections.Status. It is a
// plain string type, distinct from StoryStatus; the valid statuses are not
// visible here.
type StoryWithSectionsStatus string
// SuccessResponse is a generic success/failure response with an optional
// message.
type SuccessResponse struct {
// Message is optional and left out of the JSON when nil.
Message *string `json:"message,omitempty"`
// Success is always encoded, including when false.
Success bool `json:"success"`
}
// SystemHealthAnalytics carries background-job and performance data as
// free-form JSON objects; their schema is not defined by this type. Both
// fields are optional and left out of the JSON when nil.
//
// NOTE: "backgroundJobs" is a camelCase key, unlike the snake_case keys of
// most neighbouring models.
type SystemHealthAnalytics struct {
BackgroundJobs *map[string]interface{} `json:"backgroundJobs,omitempty"`
Performance *map[string]interface{} `json:"performance,omitempty"`
}
// TTSRequest is the request model for a text-to-speech conversion.
//
// Only Input is required; Model, StreamFormat and Voice are optional and
// omitted from the JSON when nil.
type TTSRequest struct {
// Input The text to convert to speech
Input string `json:"input"`
// Model The TTS model to use
Model *string `json:"model,omitempty"`
// StreamFormat The format for streaming audio data
StreamFormat *TTSRequestStreamFormat `json:"stream_format,omitempty"`
// Voice The voice to use for speech generation
Voice *string `json:"voice,omitempty"`
}
// TTSRequestStreamFormat The format for streaming audio data.
// It is the string-backed type of the TTSRequest.StreamFormat field.
type TTSRequestStreamFormat string
// TTSResponse is the payload of a single text-to-speech SSE event.
//
// Type names the kind of event; per the field descriptions below, Audio,
// Error and Usage are each associated with one value of Type. All fields are
// optional and omitted from the JSON when nil.
type TTSResponse struct {
// Audio Base64 encoded audio chunk (for type=audio)
Audio *string `json:"audio,omitempty"`
// Error Error message (for type=error)
Error *string `json:"error,omitempty"`
// Type The type of SSE event
Type *TTSResponseType `json:"type,omitempty"`
// Usage Usage statistics (for type=usage)
Usage *struct {
// InputTokens Number of input tokens processed
InputTokens *int `json:"input_tokens,omitempty"`
// OutputTokens Number of output tokens generated
OutputTokens *int `json:"output_tokens,omitempty"`
// TotalTokens Total tokens used
TotalTokens *int `json:"total_tokens,omitempty"`
} `json:"usage,omitempty"`
}
// TTSResponseType The type of SSE event.
// It is the string-backed type of the TTSResponse.Type field.
type TTSResponseType string
// TestAIRequest is the request model for testing an AI provider/model
// combination.
//
// Provider and Model are required. ApiKey is a pointer without omitempty, so
// a nil value encodes as JSON null rather than being dropped.
type TestAIRequest struct {
// ApiKey API key for the provider. If not provided, the server will try to use a saved key.
ApiKey *string `json:"api_key"`
// Model AI model code (e.g., "llama3", "gpt-4")
Model string `json:"model"`
// Provider AI provider code (e.g., "ollama", "openai")
Provider string `json:"provider"`
}
// ToggleAutoGenerationRequest is the request model for pausing or resuming
// auto-generation.
type ToggleAutoGenerationRequest struct {
// Paused Whether to pause (true) or resume (false) auto-generation
Paused bool `json:"paused"`
}
// ToggleAutoGenerationResponse is the response model paired with
// ToggleAutoGenerationRequest. Both fields are optional and omitted from the
// JSON when nil.
type ToggleAutoGenerationResponse struct {
AutoGenerationPaused *bool `json:"auto_generation_paused,omitempty"`
Message *string `json:"message,omitempty"`
}
// TranslateRequest is the request model for translating a piece of text.
//
// Text and TargetLanguage are required; SourceLanguage is optional.
type TranslateRequest struct {
// SourceLanguage Source language code (optional - will be auto-detected if not provided)
SourceLanguage *string `json:"source_language,omitempty"`
// TargetLanguage Target language code (e.g., 'en', 'es', 'fr')
TargetLanguage string `json:"target_language"`
// Text Text to translate
Text string `json:"text"`
}
// TranslateResponse is the response model for a translation.
//
// Every field except Confidence is a non-pointer value and is always present
// in the JSON.
type TranslateResponse struct {
// Confidence Translation confidence score (if available from provider)
Confidence *float32 `json:"confidence,omitempty"`
// SourceLanguage Detected or provided source language code
SourceLanguage string `json:"source_language"`
// TargetLanguage Target language code that was requested
TargetLanguage string `json:"target_language"`
// TranslatedText The translated text
TranslatedText string `json:"translated_text"`
}
// UpdateConversationRequest is the request model for updating a
// conversation; the only field it carries is the new title.
type UpdateConversationRequest struct {
// Title New title for the conversation
Title string `json:"title"`
}
// UpdateSnippetRequest is the request model for updating a snippet.
//
// Every field is a pointer, so each one can be left unset. Context has no
// omitempty and therefore encodes as JSON null when nil; the other fields are
// omitted when nil.
type UpdateSnippetRequest struct {
// Context User-provided context or notes about this snippet
Context *string `json:"context"`
// OriginalText The original text/word to save
OriginalText *string `json:"original_text,omitempty"`
// SourceLanguage ISO language code of the source text
SourceLanguage *string `json:"source_language,omitempty"`
// TargetLanguage ISO language code of the target translation
TargetLanguage *string `json:"target_language,omitempty"`
// TranslatedText The translated text
TranslatedText *string `json:"translated_text,omitempty"`
}
// UsageStatsResponse is the response model for usage statistics.
//
// MonthlyTotals, Services and UsageStats are non-pointer fields without
// omitempty, so they are always present in the JSON (nil maps and slices
// encode as null). CacheStats is optional.
type UsageStatsResponse struct {
// CacheStats Cache performance statistics across all services
CacheStats *struct {
// CacheHitRate Cache hit rate as a percentage
CacheHitRate *float32 `json:"cache_hit_rate,omitempty"`
// TotalCacheHitsCharacters Total characters served from cache
TotalCacheHitsCharacters *int `json:"total_cache_hits_characters,omitempty"`
// TotalCacheHitsRequests Total number of cache hit requests
TotalCacheHitsRequests *int `json:"total_cache_hits_requests,omitempty"`
// TotalCacheMissesRequests Total number of cache miss requests
TotalCacheMissesRequests *int `json:"total_cache_misses_requests,omitempty"`
} `json:"cache_stats,omitempty"`
// MonthlyTotals Monthly totals organized by month (YYYY-MM) and service
MonthlyTotals map[string]map[string]struct {
TotalCharacters *int `json:"total_characters,omitempty"`
TotalRequests *int `json:"total_requests,omitempty"`
} `json:"monthly_totals"`
// Services List of service names
Services []string `json:"services"`
// UsageStats Usage statistics organized by service, month (YYYY-MM), and usage type
// NOTE(review): the description names three levels (service, month, usage
// type) but the Go type has only two map levels - confirm against the
// OpenAPI schema which keys are actually used.
UsageStats map[string]map[string]struct {
CharactersUsed *int `json:"characters_used,omitempty"`
Quota *int `json:"quota,omitempty"`
RequestsMade *int `json:"requests_made,omitempty"`
} `json:"usage_stats"`
}
// User is the API model for a user account.
//
// All fields are pointers. Those without omitempty (AiEnabled, AiModel,
// AiProvider, CurrentLevel, Email, LastActive, PreferredLanguage, Timezone)
// encode as JSON null when nil; the rest are omitted when nil.
type User struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled"`
AiModel *string `json:"ai_model"`
AiProvider *string `json:"ai_provider"`
// CreatedAt is carried as a string; its format is not specified here.
CreatedAt *string `json:"created_at,omitempty"`
CurrentLevel *string `json:"current_level"`
Email *string `json:"email"`
// HasApiKey Whether the user has a valid API key saved for their current AI provider
HasApiKey *bool `json:"has_api_key,omitempty"`
Id *int64 `json:"id,omitempty"`
// IsPaused Whether the user is paused (question generation disabled)
IsPaused *bool `json:"is_paused,omitempty"`
// LastActive is carried as a string; its format is not specified here.
LastActive *string `json:"last_active"`
PreferredLanguage *string `json:"preferred_language"`
// Roles List of roles assigned to the user
Roles *[]Role `json:"roles,omitempty"`
Timezone *string `json:"timezone"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username *string `json:"username,omitempty"`
// WordOfDayEmailEnabled Whether the user has enabled Word of the Day emails
WordOfDayEmailEnabled *bool `json:"word_of_day_email_enabled,omitempty"`
}
// UserCreateRequest is the request model for creating a user.
//
// Username and Password are required (non-pointer, no omitempty); every other
// field is optional and omitted from the JSON when nil.
type UserCreateRequest struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled,omitempty"`
// CurrentLevel Current proficiency level
CurrentLevel *string `json:"current_level,omitempty"`
// Email Email address
Email *openapi_types.Email `json:"email,omitempty"`
// Password Password (minimum 8 characters)
Password string `json:"password"`
// PreferredLanguage Preferred learning language
PreferredLanguage *string `json:"preferred_language,omitempty"`
// Timezone Timezone (e.g., "UTC", "America/New_York")
Timezone *string `json:"timezone,omitempty"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username string `json:"username"`
}
// UserIdRequest is a request model that carries a single required user ID.
type UserIdRequest struct {
// UserId ID of the user
UserId int64 `json:"user_id"`
}
// UserLearningPreferences is the API model for a user's learning preferences.
//
// Only DailyGoal and TtsVoice are optional; all other fields are non-pointer
// values and always present in the JSON.
type UserLearningPreferences struct {
// DailyGoal User-configurable number of daily questions
DailyGoal *int `json:"daily_goal,omitempty"`
// DailyReminderEnabled Whether to receive daily reminder emails
DailyReminderEnabled bool `json:"daily_reminder_enabled"`
// FocusOnWeakAreas Whether to focus on weak areas
FocusOnWeakAreas bool `json:"focus_on_weak_areas"`
// FreshQuestionRatio Ratio of fresh (never seen) questions to show (0-1)
FreshQuestionRatio float32 `json:"fresh_question_ratio"`
// KnownQuestionPenalty Penalty multiplier for questions marked as known (0-1)
KnownQuestionPenalty float32 `json:"known_question_penalty"`
// ReviewIntervalDays Days between reviews of known questions
ReviewIntervalDays int `json:"review_interval_days"`
// TtsVoice Preferred TTS voice (e.g., it-IT-IsabellaNeural)
TtsVoice *string `json:"tts_voice,omitempty"`
// WeakAreaBoost Multiplier for weak area questions
WeakAreaBoost float32 `json:"weak_area_boost"`
}
// UserPerformanceAnalytics is the API model for per-user performance
// analytics.
//
// Both fields are free-form JSON (an object and a list of objects); note that
// their JSON keys are camelCase ("learningPreferences", "weakAreas").
type UserPerformanceAnalytics struct {
LearningPreferences *map[string]interface{} `json:"learningPreferences,omitempty"`
WeakAreas *[]map[string]interface{} `json:"weakAreas,omitempty"`
}
// UserProfile is the API model for a user's profile.
//
// All fields are pointers. Those without omitempty (AiEnabled, Email,
// LastActive, PreferredLanguage, Timezone) encode as JSON null when nil; the
// rest are omitted when nil.
type UserProfile struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled"`
// CreatedAt is carried as a string; its format is not specified here.
CreatedAt *string `json:"created_at,omitempty"`
CurrentLevel *string `json:"current_level,omitempty"`
Email *string `json:"email"`
Id *int64 `json:"id,omitempty"`
// IsPaused Whether the user is paused (question generation disabled)
IsPaused *bool `json:"is_paused,omitempty"`
LastActive *string `json:"last_active"`
PreferredLanguage *string `json:"preferred_language"`
Timezone *string `json:"timezone"`
// UpdatedAt is carried as a string; its format is not specified here.
UpdatedAt *string `json:"updated_at,omitempty"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username *string `json:"username,omitempty"`
// WordOfDayEmailEnabled Whether the user has enabled Word of the Day emails
WordOfDayEmailEnabled *bool `json:"word_of_day_email_enabled,omitempty"`
}
// UserProgress is the API model for a user's learning progress.
//
// Every field is optional and omitted from the JSON when nil.
type UserProgress struct {
AccuracyRate *float32 `json:"accuracy_rate,omitempty"`
CorrectAnswers *int `json:"correct_answers,omitempty"`
// CurrentLevel Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1–C2, JLPT N5–N1, HSK1–HSK6).
CurrentLevel *Level `json:"current_level,omitempty"`
// GapAnalysis Analysis of learning gaps and areas needing attention
GapAnalysis *map[string]interface{} `json:"gap_analysis,omitempty"`
GenerationFocus *GenerationFocus `json:"generation_focus,omitempty"`
// HighPriorityTopics Topics that have high priority scores for the user
HighPriorityTopics *[]string `json:"high_priority_topics,omitempty"`
LearningPreferences *UserLearningPreferences `json:"learning_preferences,omitempty"`
// PerformanceByTopic maps a string key (presumably the topic name - TODO
// confirm) to its PerformanceMetrics.
PerformanceByTopic *map[string]PerformanceMetrics `json:"performance_by_topic,omitempty"`
// PriorityDistribution Distribution of question priorities (high, medium, low counts)
PriorityDistribution *map[string]int `json:"priority_distribution,omitempty"`
PriorityInsights *PriorityInsights `json:"priority_insights,omitempty"`
RecentActivity *[]UserResponse `json:"recent_activity,omitempty"`
// SuggestedLevel Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1–C2, JLPT N5–N1, HSK1–HSK6).
SuggestedLevel *Level `json:"suggested_level,omitempty"`
TotalQuestions *int `json:"total_questions,omitempty"`
WeakAreas *[]string `json:"weak_areas,omitempty"`
WorkerStatus *WorkerStatus `json:"worker_status,omitempty"`
}
// UserQuestionStats is the API model for a user's question statistics.
//
// Accuracy, answered and available figures are each exposed as two
// string-keyed maps, one "by level" and one "by type". Every field is
// optional and omitted from the JSON when nil.
type UserQuestionStats struct {
AccuracyByLevel *map[string]float32 `json:"accuracy_by_level,omitempty"`
AccuracyByType *map[string]float32 `json:"accuracy_by_type,omitempty"`
AnsweredByLevel *map[string]int `json:"answered_by_level,omitempty"`
AnsweredByType *map[string]int `json:"answered_by_type,omitempty"`
AvailableByLevel *map[string]int `json:"available_by_level,omitempty"`
AvailableByType *map[string]int `json:"available_by_type,omitempty"`
TotalAnswered *int `json:"total_answered,omitempty"`
UserId *int64 `json:"user_id,omitempty"`
}
// UserResponse is the API model for one answer a user gave to a question:
// which question, whether it was correct, and a creation timestamp carried as
// a string. Every field is optional and omitted from the JSON when nil.
type UserResponse struct {
CreatedAt *string `json:"created_at,omitempty"`
IsCorrect *bool `json:"is_correct,omitempty"`
QuestionId *int64 `json:"question_id,omitempty"`
}
// UserSettings is the API model for a user's settings (language, level and
// AI configuration). Every exported field is optional and omitted from the
// JSON when nil.
type UserSettings struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled,omitempty"`
AiModel *string `json:"ai_model,omitempty"`
AiProvider *string `json:"ai_provider,omitempty"`
// ApiKey API key for AI provider (write-only)
ApiKey *string `json:"api_key,omitempty"`
// Language Learning language (dynamic). Allowed values come from config.yaml language_levels keys.
Language *Language `json:"language,omitempty"`
// Level Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1–C2, JLPT N5–N1, HSK1–HSK6).
Level *Level `json:"level,omitempty"`
// union is unexported, so encoding/json's default struct handling ignores
// it. NOTE(review): presumably custom Marshal/Unmarshal methods elsewhere in
// the file use it to hold the raw payload of the schema's union variants
// (UserSettings0 / UserSettings1) - confirm.
union json.RawMessage
}
// UserSettings0 is an alias for interface{}.
// NOTE(review): presumably the first unnamed variant of the UserSettings
// union schema - confirm against the OpenAPI spec.
type UserSettings0 = interface{}
// UserSettings1 is an alias for interface{}.
// NOTE(review): presumably the second unnamed variant of the UserSettings
// union schema - confirm against the OpenAPI spec.
type UserSettings1 = interface{}
// UserUpdateRequest is the request model for updating a user.
//
// Every exported field is an optional pointer that is omitted from the JSON
// when nil. Note that SelectedRoles uses a camelCase JSON key
// ("selectedRoles") while the other keys are snake_case.
type UserUpdateRequest struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled,omitempty"`
// AiModel AI model code
AiModel *string `json:"ai_model,omitempty"`
// AiProvider AI provider code
AiProvider *string `json:"ai_provider,omitempty"`
// ApiKey API key for AI provider (write-only)
ApiKey *string `json:"api_key,omitempty"`
// CurrentLevel Current proficiency level
CurrentLevel *string `json:"current_level,omitempty"`
// Email Email address
Email *openapi_types.Email `json:"email,omitempty"`
// PreferredLanguage Preferred learning language
PreferredLanguage *string `json:"preferred_language,omitempty"`
// SelectedRoles Array of role names to assign to the user
SelectedRoles *[]string `json:"selectedRoles,omitempty"`
// Timezone Timezone (e.g., "UTC", "America/New_York")
Timezone *string `json:"timezone,omitempty"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username *string `json:"username,omitempty"`
// union is unexported, so encoding/json's default struct handling ignores
// it. NOTE(review): presumably custom Marshal/Unmarshal methods elsewhere in
// the file use it to hold the raw payload of the schema's union variants
// (UserUpdateRequest0 / UserUpdateRequest1) - confirm.
union json.RawMessage
}
// UserUpdateRequest0 is an alias for interface{}.
// NOTE(review): presumably the first unnamed variant of the
// UserUpdateRequest union schema - confirm against the OpenAPI spec.
type UserUpdateRequest0 = interface{}
// UserUpdateRequest1 is an alias for interface{}.
// NOTE(review): presumably the second unnamed variant of the
// UserUpdateRequest union schema - confirm against the OpenAPI spec.
type UserUpdateRequest1 = interface{}
// UserUsageStats is the API model for one record of a user's AI usage:
// provider, model, service, token counts and request count for a given usage
// date and hour. Every field is optional and omitted from the JSON when nil.
type UserUsageStats struct {
ApiKeyId *int64 `json:"api_key_id,omitempty"`
CompletionTokens *int `json:"completion_tokens,omitempty"`
// CreatedAt is carried as a string; its format is not specified here.
CreatedAt *string `json:"created_at,omitempty"`
Id *int64 `json:"id,omitempty"`
Model *string `json:"model,omitempty"`
PromptTokens *int `json:"prompt_tokens,omitempty"`
Provider *string `json:"provider,omitempty"`
RequestsMade *int `json:"requests_made,omitempty"`
ServiceName *string `json:"service_name,omitempty"`
TotalTokens *int `json:"total_tokens,omitempty"`
// UpdatedAt is carried as a string; its format is not specified here.
UpdatedAt *string `json:"updated_at,omitempty"`
UsageDate *openapi_types.Date `json:"usage_date,omitempty"`
UsageHour *int `json:"usage_hour,omitempty"`
UsageType *string `json:"usage_type,omitempty"`
UserId *int64 `json:"user_id,omitempty"`
}
// UserUsageStatsDaily is the API model for AI usage totals keyed by
// UsageDate. It has the same fields as UserUsageStatsHourly except that it
// carries UsageDate instead of UsageHour. Every field is optional and omitted
// from the JSON when nil.
type UserUsageStatsDaily struct {
Model *string `json:"model,omitempty"`
Provider *string `json:"provider,omitempty"`
ServiceName *string `json:"service_name,omitempty"`
TotalCompletionTokens *int `json:"total_completion_tokens,omitempty"`
TotalPromptTokens *int `json:"total_prompt_tokens,omitempty"`
TotalRequests *int `json:"total_requests,omitempty"`
TotalTokens *int `json:"total_tokens,omitempty"`
UsageDate *openapi_types.Date `json:"usage_date,omitempty"`
UsageType *string `json:"usage_type,omitempty"`
}
// UserUsageStatsHourly is the API model for AI usage totals keyed by
// UsageHour. It has the same fields as UserUsageStatsDaily except that it
// carries UsageHour instead of UsageDate. Every field is optional and omitted
// from the JSON when nil.
type UserUsageStatsHourly struct {
Model *string `json:"model,omitempty"`
Provider *string `json:"provider,omitempty"`
ServiceName *string `json:"service_name,omitempty"`
TotalCompletionTokens *int `json:"total_completion_tokens,omitempty"`
TotalPromptTokens *int `json:"total_prompt_tokens,omitempty"`
TotalRequests *int `json:"total_requests,omitempty"`
TotalTokens *int `json:"total_tokens,omitempty"`
UsageHour *int `json:"usage_hour,omitempty"`
UsageType *string `json:"usage_type,omitempty"`
}
// WordOfDayEmailPreferenceRequest is the request model for setting the Word
// of the Day email preference.
type WordOfDayEmailPreferenceRequest struct {
// Enabled Whether to enable Word of the Day emails
Enabled bool `json:"enabled"`
}
// WordOfTheDayDisplay is the API model for displaying a word of the day.
//
// No field uses omitempty: the non-pointer fields are always present, and the
// pointer fields (Context, Explanation, Level, TopicCategory) encode as JSON
// null when nil.
type WordOfTheDayDisplay struct {
// Context Additional context for the word (primarily for snippets)
Context *string `json:"context"`
// Date Date for the word of the day (YYYY-MM-DD)
Date openapi_types.Date `json:"date"`
// Explanation Explanation of the word meaning or usage
Explanation *string `json:"explanation"`
// Language Source language of the word
Language string `json:"language"`
// Level CEFR difficulty level
Level *string `json:"level"`
// Sentence Example sentence using the word in context
Sentence string `json:"sentence"`
// SourceId ID of the source (question ID or snippet ID)
SourceId int64 `json:"source_id"`
// SourceType Source type of the word (from vocabulary question or user snippet)
SourceType WordOfTheDayDisplaySourceType `json:"source_type"`
// TopicCategory Topic category for the word
TopicCategory *string `json:"topic_category"`
// Translation English translation of the word
Translation string `json:"translation"`
// Word The word or phrase being featured
Word string `json:"word"`
}
// WordOfTheDayDisplaySourceType Source type of the word (from vocabulary question or user snippet).
// It is the string-backed type of the WordOfTheDayDisplay.SourceType field.
type WordOfTheDayDisplaySourceType string
// WorkerHealth is the API model for the health of the worker fleet: a global
// pause flag, healthy/total counts, and one entry per worker instance. Every
// field is optional and omitted from the JSON when nil.
type WorkerHealth struct {
GlobalPaused *bool `json:"global_paused,omitempty"`
HealthyCount *int `json:"healthy_count,omitempty"`
TotalCount *int `json:"total_count,omitempty"`
// WorkerInstances lists the per-instance state.
WorkerInstances *[]struct {
Healthy *bool `json:"healthy,omitempty"`
IsPaused *bool `json:"is_paused,omitempty"`
IsRunning *bool `json:"is_running,omitempty"`
// LastHeartbeat uses capitalised JSON keys ("Time", "Valid").
// NOTE(review): this looks like a serialised sql.NullTime - confirm.
LastHeartbeat *struct {
Time *string `json:"Time,omitempty"`
Valid *bool `json:"Valid,omitempty"`
} `json:"last_heartbeat,omitempty"`
TotalQuestionsGenerated *int `json:"total_questions_generated,omitempty"`
TotalRuns *int `json:"total_runs,omitempty"`
WorkerInstance *string `json:"worker_instance,omitempty"`
} `json:"worker_instances,omitempty"`
}
// WorkerStatus is the API model for a worker's status.
//
// ErrorMessage has no omitempty and encodes as JSON null when nil;
// LastHeartbeat and Status are omitted when nil.
type WorkerStatus struct {
// ErrorMessage Error message if the worker is in an error state
ErrorMessage *string `json:"error_message"`
// LastHeartbeat Timestamp of the last heartbeat from the worker
LastHeartbeat *string `json:"last_heartbeat,omitempty"`
// Status Current status of the worker
Status *WorkerStatusStatus `json:"status,omitempty"`
}
// WorkerStatusStatus Current status of the worker.
// It is the string-backed type of the WorkerStatus.Status field.
type WorkerStatusStatus string
// WorkerStatusResponse is the response model for the worker status.
//
// Every field is a non-pointer value without omitempty, so all keys are
// always present in the JSON, zero values included.
type WorkerStatusResponse struct {
// ErrorMessage Error message if worker has errors
ErrorMessage string `json:"error_message"`
// GlobalPaused Whether the worker is globally paused
GlobalPaused bool `json:"global_paused"`
// HasErrors Whether the worker has encountered errors
HasErrors bool `json:"has_errors"`
// HealthyWorkers Number of healthy worker instances
HealthyWorkers int `json:"healthy_workers"`
// LastErrorDetails Detailed error information if any
LastErrorDetails string `json:"last_error_details"`
// TotalWorkers Total number of worker instances
TotalWorkers int `json:"total_workers"`
// UserPaused Whether the user's question generation is paused
UserPaused bool `json:"user_paused"`
// WorkerRunning Whether the worker is currently running
WorkerRunning bool `json:"worker_running"`
}
// DeleteV1AdminBackendFeedbackParams holds the parameters for the
// DeleteV1AdminBackendFeedback operation. Status is required (non-pointer, no
// omitempty in either tag).
type DeleteV1AdminBackendFeedbackParams struct {
// Status Status of feedback reports to delete
Status DeleteV1AdminBackendFeedbackParamsStatus `form:"status" json:"status"`
}
// DeleteV1AdminBackendFeedbackParamsStatus is the string-backed type of the
// DeleteV1AdminBackendFeedbackParams.Status parameter.
type DeleteV1AdminBackendFeedbackParamsStatus string
// GetV1AdminBackendFeedbackParams holds the parameters for the
// GetV1AdminBackendFeedback operation: pagination plus optional filters.
// Every parameter is an optional pointer.
type GetV1AdminBackendFeedbackParams struct {
// Page Page number
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of items per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Status Filter by status
Status *GetV1AdminBackendFeedbackParamsStatus `form:"status,omitempty" json:"status,omitempty"`
// FeedbackType Filter by feedback type
FeedbackType *string `form:"feedback_type,omitempty" json:"feedback_type,omitempty"`
// UserId Filter by user ID
// NOTE(review): this is *int, whereas the other admin list parameter
// structs in this file declare user_id as *int64 - confirm the spec.
UserId *int `form:"user_id,omitempty" json:"user_id,omitempty"`
}
// GetV1AdminBackendFeedbackParamsStatus is the string-backed type of the
// GetV1AdminBackendFeedbackParams.Status filter.
type GetV1AdminBackendFeedbackParamsStatus string
// GetV1AdminBackendQuestionsParams holds the parameters for the
// GetV1AdminBackendQuestions operation: pagination, a search term and
// optional filters. Every parameter is an optional pointer. The field set is
// identical to GetV1AdminBackendQuestionsPaginatedParams.
type GetV1AdminBackendQuestionsParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of questions per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for question content
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Type Filter by question type
Type *QuestionType `form:"type,omitempty" json:"type,omitempty"`
// Status Filter by question status
Status *QuestionStatus `form:"status,omitempty" json:"status,omitempty"`
// Language Filter by language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Filter by level
Level *Level `form:"level,omitempty" json:"level,omitempty"`
// UserId Filter by user ID (optional)
UserId *int64 `form:"user_id,omitempty" json:"user_id,omitempty"`
}
// GetV1AdminBackendQuestionsPaginatedParams holds the parameters for the
// GetV1AdminBackendQuestionsPaginated operation: pagination, a search term
// and optional filters. Every parameter is an optional pointer. The field set
// is identical to GetV1AdminBackendQuestionsParams.
type GetV1AdminBackendQuestionsPaginatedParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of questions per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for question content
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Type Filter by question type
Type *QuestionType `form:"type,omitempty" json:"type,omitempty"`
// Status Filter by question status
Status *QuestionStatus `form:"status,omitempty" json:"status,omitempty"`
// Language Filter by language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Filter by level
Level *Level `form:"level,omitempty" json:"level,omitempty"`
// UserId Filter by user ID (optional)
UserId *int64 `form:"user_id,omitempty" json:"user_id,omitempty"`
}
// PutV1AdminBackendQuestionsIdJSONBody is the JSON request body for the
// PutV1AdminBackendQuestionsId operation.
//
// Content and Explanation are required; CorrectAnswer is optional and omitted
// from the JSON when nil.
type PutV1AdminBackendQuestionsIdJSONBody struct {
// Content Updated question content
Content map[string]interface{} `json:"content"`
// CorrectAnswer Index of the correct answer
CorrectAnswer *int `json:"correct_answer,omitempty"`
// Explanation Explanation for the correct answer
Explanation string `json:"explanation"`
}
// PostV1AdminBackendQuestionsIdAiFixJSONBody is the JSON request body for the
// PostV1AdminBackendQuestionsIdAiFix operation.
type PostV1AdminBackendQuestionsIdAiFixJSONBody struct {
// AdditionalContext is optional free text, omitted from the JSON when nil.
AdditionalContext *string `json:"additional_context,omitempty"`
}
// PostV1AdminBackendQuestionsIdAssignUsersJSONBody is the JSON request body
// for the PostV1AdminBackendQuestionsIdAssignUsers operation.
type PostV1AdminBackendQuestionsIdAssignUsersJSONBody struct {
// UserIds Array of user IDs to assign to the question
UserIds []int64 `json:"user_ids"`
}
// PostV1AdminBackendQuestionsIdUnassignUsersJSONBody is the JSON request body
// for the PostV1AdminBackendQuestionsIdUnassignUsers operation.
type PostV1AdminBackendQuestionsIdUnassignUsersJSONBody struct {
// UserIds Array of user IDs to unassign from the question
UserIds []int64 `json:"user_ids"`
}
// GetV1AdminBackendReportedQuestionsParams holds the parameters for the
// GetV1AdminBackendReportedQuestions operation: pagination, a search term and
// optional type/language/level filters. Every parameter is an optional
// pointer.
type GetV1AdminBackendReportedQuestionsParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of questions per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for question content
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Type Filter by question type
Type *QuestionType `form:"type,omitempty" json:"type,omitempty"`
// Language Filter by language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Filter by level
Level *Level `form:"level,omitempty" json:"level,omitempty"`
}
// GetV1AdminBackendStoriesParams holds the parameters for the
// GetV1AdminBackendStories operation: pagination, a title search term and
// optional language/status/user filters. Every parameter is an optional
// pointer.
type GetV1AdminBackendStoriesParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of stories per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for story title
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Language Filter by language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Status Filter by story status
Status *StoryStatus `form:"status,omitempty" json:"status,omitempty"`
// UserId Filter by user ID (optional)
UserId *int64 `form:"user_id,omitempty" json:"user_id,omitempty"`
}
// PostV1AdminBackendUserzJSONBody is the JSON request body for the
// PostV1AdminBackendUserz operation.
//
// Username, Email and Password are required (non-pointer, no omitempty); the
// AI settings, Language and Level are optional and omitted from the JSON when
// nil.
type PostV1AdminBackendUserzJSONBody struct {
// AiEnabled Whether AI is enabled for this user
AiEnabled *bool `json:"ai_enabled,omitempty"`
// AiModel AI model preference
AiModel *string `json:"ai_model,omitempty"`
// AiProvider AI provider preference
AiProvider *string `json:"ai_provider,omitempty"`
// Email Email address for the new user
Email openapi_types.Email `json:"email"`
// Language Preferred language for the user
Language *string `json:"language,omitempty"`
// Level Current level for the user
Level *string `json:"level,omitempty"`
// Password Password for the new user
Password string `json:"password"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username string `json:"username"`
}
// GetV1AdminBackendUserzPaginatedParams holds the parameters for the
// GetV1AdminBackendUserzPaginated operation: pagination, a search term and
// optional filters. Every parameter is an optional pointer.
//
// AiEnabled and Active are string-backed types rather than bool.
type GetV1AdminBackendUserzPaginatedParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of users per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for username or email
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Language Filter by preferred language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Filter by current level
Level *Level `form:"level,omitempty" json:"level,omitempty"`
// AiProvider Filter by AI provider
AiProvider *string `form:"ai_provider,omitempty" json:"ai_provider,omitempty"`
// AiModel Filter by AI model
AiModel *string `form:"ai_model,omitempty" json:"ai_model,omitempty"`
// AiEnabled Filter by AI enabled status
AiEnabled *GetV1AdminBackendUserzPaginatedParamsAiEnabled `form:"ai_enabled,omitempty" json:"ai_enabled,omitempty"`
// Active Filter by active status (active within 7 days)
Active *GetV1AdminBackendUserzPaginatedParamsActive `form:"active,omitempty" json:"active,omitempty"`
}
// GetV1AdminBackendUserzPaginatedParamsAiEnabled is the string-backed type of
// the GetV1AdminBackendUserzPaginatedParams.AiEnabled filter.
type GetV1AdminBackendUserzPaginatedParamsAiEnabled string
// GetV1AdminBackendUserzPaginatedParamsActive is the string-backed type of
// the GetV1AdminBackendUserzPaginatedParams.Active filter.
type GetV1AdminBackendUserzPaginatedParamsActive string
// PostV1AdminBackendUserzIdRolesJSONBody is the JSON request body for the
// PostV1AdminBackendUserzIdRoles operation.
type PostV1AdminBackendUserzIdRolesJSONBody struct {
// RoleId Role ID to assign
RoleId int64 `json:"role_id"`
}
// GetV1AdminWorkerNotificationsErrorsParams holds the parameters for the
// GetV1AdminWorkerNotificationsErrors operation: pagination plus optional
// filters. Every parameter is an optional pointer, and all three filters are
// string-backed types.
type GetV1AdminWorkerNotificationsErrorsParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of errors per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// ErrorType Filter by error type
ErrorType *GetV1AdminWorkerNotificationsErrorsParamsErrorType `form:"error_type,omitempty" json:"error_type,omitempty"`
// NotificationType Filter by notification type
NotificationType *GetV1AdminWorkerNotificationsErrorsParamsNotificationType `form:"notification_type,omitempty" json:"notification_type,omitempty"`
// Resolved Filter by resolution status
Resolved *GetV1AdminWorkerNotificationsErrorsParamsResolved `form:"resolved,omitempty" json:"resolved,omitempty"`
}
// GetV1AdminWorkerNotificationsErrorsParamsErrorType is the string-backed
// type of the GetV1AdminWorkerNotificationsErrorsParams.ErrorType filter.
type GetV1AdminWorkerNotificationsErrorsParamsErrorType string
// GetV1AdminWorkerNotificationsErrorsParamsNotificationType is the
// string-backed type of the
// GetV1AdminWorkerNotificationsErrorsParams.NotificationType filter.
type GetV1AdminWorkerNotificationsErrorsParamsNotificationType string
// GetV1AdminWorkerNotificationsErrorsParamsResolved is the string-backed type
// of the GetV1AdminWorkerNotificationsErrorsParams.Resolved filter.
type GetV1AdminWorkerNotificationsErrorsParamsResolved string
// PostV1AdminWorkerNotificationsForceSendJSONBody is the JSON request body
// for the PostV1AdminWorkerNotificationsForceSend operation.
type PostV1AdminWorkerNotificationsForceSendJSONBody struct {
// Username Username of the user to send notification to
Username string `json:"username"`
}
// GetV1AdminWorkerNotificationsSentParams holds the parameters for the
// GetV1AdminWorkerNotificationsSent operation: pagination plus optional
// filters. Every parameter is an optional pointer.
//
// SentAfter and SentBefore are plain strings; the timestamp format they
// expect is not specified in this declaration.
type GetV1AdminWorkerNotificationsSentParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of notifications per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// NotificationType Filter by notification type
NotificationType *GetV1AdminWorkerNotificationsSentParamsNotificationType `form:"notification_type,omitempty" json:"notification_type,omitempty"`
// Status Filter by status
Status *GetV1AdminWorkerNotificationsSentParamsStatus `form:"status,omitempty" json:"status,omitempty"`
// SentAfter Filter notifications sent after this timestamp
SentAfter *string `form:"sent_after,omitempty" json:"sent_after,omitempty"`
// SentBefore Filter notifications sent before this timestamp
SentBefore *string `form:"sent_before,omitempty" json:"sent_before,omitempty"`
}
// GetV1AdminWorkerNotificationsSentParamsNotificationType is the
// string-backed type of the
// GetV1AdminWorkerNotificationsSentParams.NotificationType filter.
type GetV1AdminWorkerNotificationsSentParamsNotificationType string
// GetV1AdminWorkerNotificationsSentParamsStatus is the string-backed type of
// the GetV1AdminWorkerNotificationsSentParams.Status filter.
type GetV1AdminWorkerNotificationsSentParamsStatus string
// GetV1AiBookmarksParams holds the parameters for the GetV1AiBookmarks
// operation: an optional search query plus limit/offset paging. Every
// parameter is an optional pointer.
type GetV1AiBookmarksParams struct {
// Q Optional search query to filter bookmarked messages
Q *string `form:"q,omitempty" json:"q,omitempty"`
// Limit Maximum number of messages to return
Limit *int `form:"limit,omitempty" json:"limit,omitempty"`
// Offset Number of messages to skip
Offset *int `form:"offset,omitempty" json:"offset,omitempty"`
}
// GetV1AiConversationsParams holds the parameters for the
// GetV1AiConversations operation: optional limit/offset paging.
type GetV1AiConversationsParams struct {
// Limit Maximum number of conversations to return
Limit *int `form:"limit,omitempty" json:"limit,omitempty"`
// Offset Number of conversations to skip
Offset *int `form:"offset,omitempty" json:"offset,omitempty"`
}
// PutV1AiConversationsBookmarkJSONBody is the JSON request body for the
// PutV1AiConversationsBookmark operation. Both UUIDs are required.
type PutV1AiConversationsBookmarkJSONBody struct {
// ConversationId ID of the conversation containing the message
ConversationId openapi_types.UUID `json:"conversation_id"`
// MessageId ID of the message to bookmark/unbookmark
MessageId openapi_types.UUID `json:"message_id"`
}
// GetV1AiSearchParams defines parameters for GetV1AiSearch.
// Q is a plain string without omitempty; Limit and Offset are optional
// pointers tagged omitempty.
type GetV1AiSearchParams struct {
// Q Search query string (form/JSON key "q")
Q string `form:"q" json:"q"`
// Limit Maximum number of results to return (form/JSON key "limit")
Limit *int `form:"limit,omitempty" json:"limit,omitempty"`
// Offset Number of results to skip (form/JSON key "offset")
Offset *int `form:"offset,omitempty" json:"offset,omitempty"`
}
// GetV1AuthGoogleCallbackParams defines parameters for GetV1AuthGoogleCallback.
// Code is a plain string without omitempty; State is an optional pointer
// tagged omitempty.
type GetV1AuthGoogleCallbackParams struct {
// Code Authorization code from Google (form/JSON key "code")
Code string `form:"code" json:"code"`
// State State parameter for CSRF protection (form/JSON key "state")
State *string `form:"state,omitempty" json:"state,omitempty"`
}
// PostV1DailyQuestionsDateAnswerQuestionIdJSONBody defines parameters for PostV1DailyQuestionsDateAnswerQuestionId.
// The single field is a non-pointer int without omitempty, so it is always
// present in the encoded JSON body (a zero index encodes as 0).
type PostV1DailyQuestionsDateAnswerQuestionIdJSONBody struct {
// UserAnswerIndex Index of the user's selected answer (0-based); JSON key "user_answer_index"
UserAnswerIndex int `json:"user_answer_index"`
}
// GetV1QuizAiTokenUsageParams defines parameters for GetV1QuizAiTokenUsage.
// Both dates are non-pointer values without omitempty. Note that the keys are
// camelCase ("startDate", "endDate"), unlike the snake_case keys used by most
// other parameter structs in this file.
type GetV1QuizAiTokenUsageParams struct {
// StartDate Start date in YYYY-MM-DD format (form/JSON key "startDate")
StartDate openapi_types.Date `form:"startDate" json:"startDate"`
// EndDate End date in YYYY-MM-DD format (form/JSON key "endDate")
EndDate openapi_types.Date `form:"endDate" json:"endDate"`
}
// GetV1QuizAiTokenUsageDailyParams defines parameters for GetV1QuizAiTokenUsageDaily.
// Both dates are non-pointer values without omitempty and use camelCase keys
// ("startDate", "endDate").
type GetV1QuizAiTokenUsageDailyParams struct {
// StartDate Start date in YYYY-MM-DD format (form/JSON key "startDate")
StartDate openapi_types.Date `form:"startDate" json:"startDate"`
// EndDate End date in YYYY-MM-DD format (form/JSON key "endDate")
EndDate openapi_types.Date `form:"endDate" json:"endDate"`
}
// GetV1QuizAiTokenUsageHourlyParams defines parameters for GetV1QuizAiTokenUsageHourly.
// The single date is a non-pointer value without omitempty.
type GetV1QuizAiTokenUsageHourlyParams struct {
// Date Date in YYYY-MM-DD format (form/JSON key "date")
Date openapi_types.Date `form:"date" json:"date"`
}
// GetV1QuizQuestionParams defines parameters for GetV1QuizQuestion.
// Every field is a pointer tagged omitempty, so each parameter is optional.
// Type and ExcludeType carry their comma-separated lists as one raw string;
// this struct does not split them.
type GetV1QuizQuestionParams struct {
// Language Preferred language for the question (form/JSON key "language")
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Difficulty level for the question (form/JSON key "level")
Level *Level `form:"level,omitempty" json:"level,omitempty"`
// Type Specific question type(s) to retrieve (comma-separated list). If multiple types are provided, the first valid type will be used.
Type *string `form:"type,omitempty" json:"type,omitempty"`
// ExcludeType Question type(s) to exclude from random selection (comma-separated list). Useful for filtering out specific question types from the general quiz.
ExcludeType *string `form:"exclude_type,omitempty" json:"exclude_type,omitempty"`
}
// GetV1SettingsLevelsParams defines parameters for GetV1SettingsLevels.
// The single field is a pointer tagged omitempty, so the parameter is optional.
type GetV1SettingsLevelsParams struct {
// Language Language to get levels for (optional - returns all levels if not specified); form/JSON key "language"
Language *string `form:"language,omitempty" json:"language,omitempty"`
}
// GetV1SnippetsParams defines parameters for GetV1Snippets.
// Every field is a pointer tagged omitempty, so each filter and paging
// parameter is optional.
type GetV1SnippetsParams struct {
// Q Optional search query to filter snippets by text content (form/JSON key "q")
Q *string `form:"q,omitempty" json:"q,omitempty"`
// SourceLang Filter by source language (form/JSON key "source_lang")
SourceLang *string `form:"source_lang,omitempty" json:"source_lang,omitempty"`
// TargetLang Filter by target language (form/JSON key "target_lang")
TargetLang *string `form:"target_lang,omitempty" json:"target_lang,omitempty"`
// StoryId Filter by story ID (form/JSON key "story_id"); the ID is a 64-bit integer
StoryId *int64 `form:"story_id,omitempty" json:"story_id,omitempty"`
// Level Filter by difficulty level (CEFR level); form/JSON key "level"
Level *GetV1SnippetsParamsLevel `form:"level,omitempty" json:"level,omitempty"`
// Limit Maximum number of snippets to return (default 50, max 100)
Limit *int `form:"limit,omitempty" json:"limit,omitempty"`
// Offset Number of snippets to skip for pagination
Offset *int `form:"offset,omitempty" json:"offset,omitempty"`
}
// GetV1SnippetsParamsLevel defines parameters for GetV1Snippets.
// It is the string-based type of the optional Level filter field of
// GetV1SnippetsParams.
type GetV1SnippetsParamsLevel string
// GetV1SnippetsSearchParams defines parameters for GetV1SnippetsSearch.
// Q is a plain string without omitempty; the remaining fields are optional
// pointers tagged omitempty.
type GetV1SnippetsSearchParams struct {
// Q Search query string (form/JSON key "q")
Q string `form:"q" json:"q"`
// SourceLang Filter results by source language (form/JSON key "source_lang")
SourceLang *string `form:"source_lang,omitempty" json:"source_lang,omitempty"`
// Limit Maximum number of results to return (form/JSON key "limit")
Limit *int `form:"limit,omitempty" json:"limit,omitempty"`
// Offset Number of results to skip (form/JSON key "offset")
Offset *int `form:"offset,omitempty" json:"offset,omitempty"`
}
// GetV1StoryParams defines parameters for GetV1Story.
// The single field is a pointer tagged omitempty, so the parameter is optional.
type GetV1StoryParams struct {
// IncludeArchived Include archived stories in the response (form/JSON key "include_archived")
IncludeArchived *bool `form:"include_archived,omitempty" json:"include_archived,omitempty"`
}
// GetV1WordOfDayEmbedParams defines parameters for GetV1WordOfDayEmbed.
// The single field is a pointer tagged omitempty, so the parameter is optional.
type GetV1WordOfDayEmbedParams struct {
// Date Optional date in YYYY-MM-DD format. Defaults to today's date in the user's timezone when omitted.
// The defaulting is not performed by this struct.
Date *openapi_types.Date `form:"date,omitempty" json:"date,omitempty"`
}
// GetV1WordOfDayHistoryParams defines parameters for GetV1WordOfDayHistory.
// Both dates are non-pointer values without omitempty and use snake_case keys
// ("start_date", "end_date").
type GetV1WordOfDayHistoryParams struct {
// StartDate Start date in YYYY-MM-DD format (form/JSON key "start_date")
StartDate openapi_types.Date `form:"start_date" json:"start_date"`
// EndDate End date in YYYY-MM-DD format (form/JSON key "end_date")
EndDate openapi_types.Date `form:"end_date" json:"end_date"`
}
// The request-body types below come in two forms: declarations written with
// "=" are aliases of a shared schema type, while the others are new defined
// types whose underlying type is the operation-specific ...JSONBody struct.

// PatchV1AdminBackendFeedbackIdJSONRequestBody defines body for PatchV1AdminBackendFeedbackId for application/json ContentType.
type PatchV1AdminBackendFeedbackIdJSONRequestBody = FeedbackUpdateRequest
// PutV1AdminBackendQuestionsIdJSONRequestBody defines body for PutV1AdminBackendQuestionsId for application/json ContentType.
type PutV1AdminBackendQuestionsIdJSONRequestBody PutV1AdminBackendQuestionsIdJSONBody
// PostV1AdminBackendQuestionsIdAiFixJSONRequestBody defines body for PostV1AdminBackendQuestionsIdAiFix for application/json ContentType.
type PostV1AdminBackendQuestionsIdAiFixJSONRequestBody PostV1AdminBackendQuestionsIdAiFixJSONBody
// PostV1AdminBackendQuestionsIdAssignUsersJSONRequestBody defines body for PostV1AdminBackendQuestionsIdAssignUsers for application/json ContentType.
type PostV1AdminBackendQuestionsIdAssignUsersJSONRequestBody PostV1AdminBackendQuestionsIdAssignUsersJSONBody
// PostV1AdminBackendQuestionsIdUnassignUsersJSONRequestBody defines body for PostV1AdminBackendQuestionsIdUnassignUsers for application/json ContentType.
type PostV1AdminBackendQuestionsIdUnassignUsersJSONRequestBody PostV1AdminBackendQuestionsIdUnassignUsersJSONBody
// PostV1AdminBackendUserzJSONRequestBody defines body for PostV1AdminBackendUserz for application/json ContentType.
type PostV1AdminBackendUserzJSONRequestBody PostV1AdminBackendUserzJSONBody
// PutV1AdminBackendUserzIdJSONRequestBody defines body for PutV1AdminBackendUserzId for application/json ContentType.
type PutV1AdminBackendUserzIdJSONRequestBody = UserUpdateRequest
// PostV1AdminBackendUserzIdResetPasswordJSONRequestBody defines body for PostV1AdminBackendUserzIdResetPassword for application/json ContentType.
type PostV1AdminBackendUserzIdResetPasswordJSONRequestBody = PasswordResetRequest
// PostV1AdminBackendUserzIdRolesJSONRequestBody defines body for PostV1AdminBackendUserzIdRoles for application/json ContentType.
type PostV1AdminBackendUserzIdRolesJSONRequestBody PostV1AdminBackendUserzIdRolesJSONBody
// PostV1AdminWorkerNotificationsForceSendJSONRequestBody defines body for PostV1AdminWorkerNotificationsForceSend for application/json ContentType.
type PostV1AdminWorkerNotificationsForceSendJSONRequestBody PostV1AdminWorkerNotificationsForceSendJSONBody
// PostV1AdminWorkerUsersPauseJSONRequestBody defines body for PostV1AdminWorkerUsersPause for application/json ContentType.
type PostV1AdminWorkerUsersPauseJSONRequestBody = UserIdRequest
// PostV1AdminWorkerUsersResumeJSONRequestBody defines body for PostV1AdminWorkerUsersResume for application/json ContentType.
type PostV1AdminWorkerUsersResumeJSONRequestBody = UserIdRequest
// PostV1AiConversationsJSONRequestBody defines body for PostV1AiConversations for application/json ContentType.
type PostV1AiConversationsJSONRequestBody = CreateConversationRequest
// PutV1AiConversationsBookmarkJSONRequestBody defines body for PutV1AiConversationsBookmark for application/json ContentType.
type PutV1AiConversationsBookmarkJSONRequestBody PutV1AiConversationsBookmarkJSONBody
// PostV1AiConversationsConversationIdMessagesJSONRequestBody defines body for PostV1AiConversationsConversationIdMessages for application/json ContentType.
type PostV1AiConversationsConversationIdMessagesJSONRequestBody = CreateMessageRequest
// PutV1AiConversationsIdJSONRequestBody defines body for PutV1AiConversationsId for application/json ContentType.
type PutV1AiConversationsIdJSONRequestBody = UpdateConversationRequest
// PostV1ApiKeysJSONRequestBody defines body for PostV1ApiKeys for application/json ContentType.
type PostV1ApiKeysJSONRequestBody = CreateAPIKeyRequest
// PostV1AudioSpeechJSONRequestBody defines body for PostV1AudioSpeech for application/json ContentType.
type PostV1AudioSpeechJSONRequestBody = TTSRequest
// PostV1AudioSpeechInitJSONRequestBody defines body for PostV1AudioSpeechInit for application/json ContentType.
type PostV1AudioSpeechInitJSONRequestBody = TTSRequest
// PostV1AuthLoginJSONRequestBody defines body for PostV1AuthLogin for application/json ContentType.
type PostV1AuthLoginJSONRequestBody = LoginRequest
// PostV1AuthSignupJSONRequestBody defines body for PostV1AuthSignup for application/json ContentType.
type PostV1AuthSignupJSONRequestBody = UserCreateRequest
// PostV1DailyQuestionsDateAnswerQuestionIdJSONRequestBody defines body for PostV1DailyQuestionsDateAnswerQuestionId for application/json ContentType.
type PostV1DailyQuestionsDateAnswerQuestionIdJSONRequestBody PostV1DailyQuestionsDateAnswerQuestionIdJSONBody
// PostV1FeedbackJSONRequestBody defines body for PostV1Feedback for application/json ContentType.
type PostV1FeedbackJSONRequestBody = FeedbackSubmissionRequest
// PutV1PreferencesLearningJSONRequestBody defines body for PutV1PreferencesLearning for application/json ContentType.
type PutV1PreferencesLearningJSONRequestBody = UserLearningPreferences
// PostV1QuizAnswerJSONRequestBody defines body for PostV1QuizAnswer for application/json ContentType.
type PostV1QuizAnswerJSONRequestBody = AnswerRequest
// PostV1QuizChatStreamJSONRequestBody defines body for PostV1QuizChatStream for application/json ContentType.
type PostV1QuizChatStreamJSONRequestBody = QuizChatRequest
// PostV1QuizQuestionIdMarkKnownJSONRequestBody defines body for PostV1QuizQuestionIdMarkKnown for application/json ContentType.
type PostV1QuizQuestionIdMarkKnownJSONRequestBody = MarkQuestionKnownRequest
// PostV1QuizQuestionIdReportJSONRequestBody defines body for PostV1QuizQuestionIdReport for application/json ContentType.
type PostV1QuizQuestionIdReportJSONRequestBody = ReportQuestionRequest
// PutV1SettingsJSONRequestBody defines body for PutV1Settings for application/json ContentType.
type PutV1SettingsJSONRequestBody = UserSettings
// PostV1SettingsTestAiJSONRequestBody defines body for PostV1SettingsTestAi for application/json ContentType.
type PostV1SettingsTestAiJSONRequestBody = TestAIRequest
// PutV1SettingsWordOfDayEmailJSONRequestBody defines body for PutV1SettingsWordOfDayEmail for application/json ContentType.
type PutV1SettingsWordOfDayEmailJSONRequestBody = WordOfDayEmailPreferenceRequest
// PostV1SnippetsJSONRequestBody defines body for PostV1Snippets for application/json ContentType.
type PostV1SnippetsJSONRequestBody = CreateSnippetRequest
// PutV1SnippetsIdJSONRequestBody defines body for PutV1SnippetsId for application/json ContentType.
type PutV1SnippetsIdJSONRequestBody = UpdateSnippetRequest
// PostV1StoryJSONRequestBody defines body for PostV1Story for application/json ContentType.
type PostV1StoryJSONRequestBody = CreateStoryRequest
// PostV1StoryIdGenerateJSONRequestBody defines body for PostV1StoryIdGenerate for application/json ContentType.
type PostV1StoryIdGenerateJSONRequestBody = EmptyRequest
// PostV1StoryIdToggleAutoGenerationJSONRequestBody defines body for PostV1StoryIdToggleAutoGeneration for application/json ContentType.
type PostV1StoryIdToggleAutoGenerationJSONRequestBody = ToggleAutoGenerationRequest
// PostV1TranslateJSONRequestBody defines body for PostV1Translate for application/json ContentType.
type PostV1TranslateJSONRequestBody = TranslateRequest
// PutV1UserzProfileJSONRequestBody defines body for PutV1UserzProfile for application/json ContentType.
type PutV1UserzProfileJSONRequestBody = UserUpdateRequest
// AsServiceVersion decodes the union payload held by the AggregatedVersion_Worker into a ServiceVersion.
// The decoded value is returned together with any error reported by json.Unmarshal.
func (t AggregatedVersion_Worker) AsServiceVersion() (ServiceVersion, error) {
	var decoded ServiceVersion
	if err := json.Unmarshal(t.union, &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
// FromServiceVersion replaces the union payload of the AggregatedVersion_Worker with the JSON encoding
// of the provided ServiceVersion. The payload is overwritten with whatever json.Marshal returned even
// when marshaling fails; the marshaling error is returned.
func (t *AggregatedVersion_Worker) FromServiceVersion(v ServiceVersion) error {
	var err error
	t.union, err = json.Marshal(v)
	return err
}
// MergeServiceVersion marshals the provided ServiceVersion and merges it into the union payload already
// held by the AggregatedVersion_Worker using runtime.JSONMerge. A marshaling failure leaves the payload
// untouched; otherwise the payload is replaced by the value JSONMerge returns, and its error is returned.
func (t *AggregatedVersion_Worker) MergeServiceVersion(v ServiceVersion) error {
	patch, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union, err = runtime.JSONMerge(t.union, patch)
	return err
}
// AsAggregatedVersionWorker1 decodes the union payload held by the AggregatedVersion_Worker into an
// AggregatedVersionWorker1. The decoded value is returned together with any error reported by json.Unmarshal.
func (t AggregatedVersion_Worker) AsAggregatedVersionWorker1() (AggregatedVersionWorker1, error) {
	var decoded AggregatedVersionWorker1
	if err := json.Unmarshal(t.union, &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
// FromAggregatedVersionWorker1 replaces the union payload of the AggregatedVersion_Worker with the JSON
// encoding of the provided AggregatedVersionWorker1. The payload is overwritten with whatever json.Marshal
// returned even when marshaling fails; the marshaling error is returned.
func (t *AggregatedVersion_Worker) FromAggregatedVersionWorker1(v AggregatedVersionWorker1) error {
	var err error
	t.union, err = json.Marshal(v)
	return err
}
// MergeAggregatedVersionWorker1 marshals the provided AggregatedVersionWorker1 and merges it into the union
// payload already held by the AggregatedVersion_Worker using runtime.JSONMerge. A marshaling failure leaves
// the payload untouched; otherwise the payload is replaced by the value JSONMerge returns, and its error is returned.
func (t *AggregatedVersion_Worker) MergeAggregatedVersionWorker1(v AggregatedVersionWorker1) error {
	patch, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union, err = runtime.JSONMerge(t.union, patch)
	return err
}
// MarshalJSON encodes the AggregatedVersion_Worker by delegating to the MarshalJSON method of its
// union payload.
func (t AggregatedVersion_Worker) MarshalJSON() ([]byte, error) {
	return t.union.MarshalJSON()
}
// UnmarshalJSON loads b into the AggregatedVersion_Worker by delegating to the UnmarshalJSON method of
// its union payload.
func (t *AggregatedVersion_Worker) UnmarshalJSON(b []byte) error {
	return t.union.UnmarshalJSON(b)
}
// AsUserSettings0 decodes the union payload held by the UserSettings into a UserSettings0.
// The decoded value is returned together with any error reported by json.Unmarshal.
func (t UserSettings) AsUserSettings0() (UserSettings0, error) {
	var decoded UserSettings0
	if err := json.Unmarshal(t.union, &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
// FromUserSettings0 replaces the union payload of the UserSettings with the JSON encoding of the
// provided UserSettings0. The payload is overwritten with whatever json.Marshal returned even when
// marshaling fails; the marshaling error is returned.
func (t *UserSettings) FromUserSettings0(v UserSettings0) error {
	var err error
	t.union, err = json.Marshal(v)
	return err
}
// MergeUserSettings0 marshals the provided UserSettings0 and merges it into the union payload already
// held by the UserSettings using runtime.JSONMerge. A marshaling failure leaves the payload untouched;
// otherwise the payload is replaced by the value JSONMerge returns, and its error is returned.
func (t *UserSettings) MergeUserSettings0(v UserSettings0) error {
	patch, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union, err = runtime.JSONMerge(t.union, patch)
	return err
}
// AsUserSettings1 decodes the union payload held by the UserSettings into a UserSettings1.
// The decoded value is returned together with any error reported by json.Unmarshal.
func (t UserSettings) AsUserSettings1() (UserSettings1, error) {
	var decoded UserSettings1
	if err := json.Unmarshal(t.union, &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
// FromUserSettings1 replaces the union payload of the UserSettings with the JSON encoding of the
// provided UserSettings1. The payload is overwritten with whatever json.Marshal returned even when
// marshaling fails; the marshaling error is returned.
func (t *UserSettings) FromUserSettings1(v UserSettings1) error {
	var err error
	t.union, err = json.Marshal(v)
	return err
}
// MergeUserSettings1 marshals the provided UserSettings1 and merges it into the union payload already
// held by the UserSettings using runtime.JSONMerge. A marshaling failure leaves the payload untouched;
// otherwise the payload is replaced by the value JSONMerge returns, and its error is returned.
func (t *UserSettings) MergeUserSettings1(v UserSettings1) error {
	patch, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union, err = runtime.JSONMerge(t.union, patch)
	return err
}
// MarshalJSON serializes the UserSettings as a single JSON value.
//
// The union payload is decoded into a key/value map first; every non-nil typed
// field (ai_enabled, ai_model, ai_provider, api_key, language, level) is then
// marshaled and written over the matching key, and the map is encoded.
//
// A union payload holding the JSON literal null makes json.Unmarshal reset the
// map to nil. Typed fields are therefore stored through a helper that
// re-creates the map on demand instead of panicking on a nil-map write. When
// no typed field is set, the nil map is still encoded as null.
//
// Errors from marshaling a typed field are wrapped with the field's JSON key.
func (t UserSettings) MarshalJSON() ([]byte, error) {
	b, err := t.union.MarshalJSON()
	if err != nil {
		return nil, err
	}
	object := make(map[string]json.RawMessage)
	if t.union != nil {
		if err = json.Unmarshal(b, &object); err != nil {
			return nil, err
		}
	}
	// set writes one encoded field, re-allocating the map if a null union cleared it.
	set := func(key string, raw json.RawMessage) {
		if object == nil {
			object = make(map[string]json.RawMessage)
		}
		object[key] = raw
	}
	if t.AiEnabled != nil {
		raw, err := json.Marshal(t.AiEnabled)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_enabled': %w", err)
		}
		set("ai_enabled", raw)
	}
	if t.AiModel != nil {
		raw, err := json.Marshal(t.AiModel)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_model': %w", err)
		}
		set("ai_model", raw)
	}
	if t.AiProvider != nil {
		raw, err := json.Marshal(t.AiProvider)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_provider': %w", err)
		}
		set("ai_provider", raw)
	}
	if t.ApiKey != nil {
		raw, err := json.Marshal(t.ApiKey)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'api_key': %w", err)
		}
		set("api_key", raw)
	}
	if t.Language != nil {
		raw, err := json.Marshal(t.Language)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'language': %w", err)
		}
		set("language", raw)
	}
	if t.Level != nil {
		raw, err := json.Marshal(t.Level)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'level': %w", err)
		}
		set("level", raw)
	}
	return json.Marshal(object)
}
// UnmarshalJSON loads b into the UserSettings.
//
// The payload is first handed to the union's own UnmarshalJSON, then decoded
// into a key/value map from which the typed fields (ai_enabled, ai_model,
// ai_provider, api_key, language, level) are populated. Keys absent from the
// payload leave the corresponding field untouched. A failure to decode a
// typed field is wrapped with that field's JSON key.
func (t *UserSettings) UnmarshalJSON(b []byte) error {
	if err := t.union.UnmarshalJSON(b); err != nil {
		return err
	}
	fields := make(map[string]json.RawMessage)
	if err := json.Unmarshal(b, &fields); err != nil {
		return err
	}
	if value, ok := fields["ai_enabled"]; ok {
		if err := json.Unmarshal(value, &t.AiEnabled); err != nil {
			return fmt.Errorf("error reading 'ai_enabled': %w", err)
		}
	}
	if value, ok := fields["ai_model"]; ok {
		if err := json.Unmarshal(value, &t.AiModel); err != nil {
			return fmt.Errorf("error reading 'ai_model': %w", err)
		}
	}
	if value, ok := fields["ai_provider"]; ok {
		if err := json.Unmarshal(value, &t.AiProvider); err != nil {
			return fmt.Errorf("error reading 'ai_provider': %w", err)
		}
	}
	if value, ok := fields["api_key"]; ok {
		if err := json.Unmarshal(value, &t.ApiKey); err != nil {
			return fmt.Errorf("error reading 'api_key': %w", err)
		}
	}
	if value, ok := fields["language"]; ok {
		if err := json.Unmarshal(value, &t.Language); err != nil {
			return fmt.Errorf("error reading 'language': %w", err)
		}
	}
	if value, ok := fields["level"]; ok {
		if err := json.Unmarshal(value, &t.Level); err != nil {
			return fmt.Errorf("error reading 'level': %w", err)
		}
	}
	return nil
}
// AsUserUpdateRequest0 decodes the union payload held by the UserUpdateRequest into a UserUpdateRequest0.
// The decoded value is returned together with any error reported by json.Unmarshal.
func (t UserUpdateRequest) AsUserUpdateRequest0() (UserUpdateRequest0, error) {
	var decoded UserUpdateRequest0
	if err := json.Unmarshal(t.union, &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
// FromUserUpdateRequest0 replaces the union payload of the UserUpdateRequest with the JSON encoding of
// the provided UserUpdateRequest0. The payload is overwritten with whatever json.Marshal returned even
// when marshaling fails; the marshaling error is returned.
func (t *UserUpdateRequest) FromUserUpdateRequest0(v UserUpdateRequest0) error {
	var err error
	t.union, err = json.Marshal(v)
	return err
}
// MergeUserUpdateRequest0 marshals the provided UserUpdateRequest0 and merges it into the union payload
// already held by the UserUpdateRequest using runtime.JSONMerge. A marshaling failure leaves the payload
// untouched; otherwise the payload is replaced by the value JSONMerge returns, and its error is returned.
func (t *UserUpdateRequest) MergeUserUpdateRequest0(v UserUpdateRequest0) error {
	patch, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union, err = runtime.JSONMerge(t.union, patch)
	return err
}
// AsUserUpdateRequest1 decodes the union payload held by the UserUpdateRequest into a UserUpdateRequest1.
// The decoded value is returned together with any error reported by json.Unmarshal.
func (t UserUpdateRequest) AsUserUpdateRequest1() (UserUpdateRequest1, error) {
	var decoded UserUpdateRequest1
	if err := json.Unmarshal(t.union, &decoded); err != nil {
		return decoded, err
	}
	return decoded, nil
}
// FromUserUpdateRequest1 replaces the union payload of the UserUpdateRequest with the JSON encoding of
// the provided UserUpdateRequest1. The payload is overwritten with whatever json.Marshal returned even
// when marshaling fails; the marshaling error is returned.
func (t *UserUpdateRequest) FromUserUpdateRequest1(v UserUpdateRequest1) error {
	var err error
	t.union, err = json.Marshal(v)
	return err
}
// MergeUserUpdateRequest1 marshals the provided UserUpdateRequest1 and merges it into the union payload
// already held by the UserUpdateRequest using runtime.JSONMerge. A marshaling failure leaves the payload
// untouched; otherwise the payload is replaced by the value JSONMerge returns, and its error is returned.
func (t *UserUpdateRequest) MergeUserUpdateRequest1(v UserUpdateRequest1) error {
	patch, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union, err = runtime.JSONMerge(t.union, patch)
	return err
}
// MarshalJSON serializes the UserUpdateRequest as a single JSON value.
//
// The union payload is decoded into a key/value map first; every non-nil typed
// field (ai_enabled, ai_model, ai_provider, api_key, current_level, email,
// preferred_language, selectedRoles, timezone, username) is then marshaled and
// written over the matching key, and the map is encoded.
//
// A union payload holding the JSON literal null makes json.Unmarshal reset the
// map to nil. Typed fields are therefore stored through a helper that
// re-creates the map on demand instead of panicking on a nil-map write. When
// no typed field is set, the nil map is still encoded as null.
//
// Errors from marshaling a typed field are wrapped with the field's JSON key.
func (t UserUpdateRequest) MarshalJSON() ([]byte, error) {
	b, err := t.union.MarshalJSON()
	if err != nil {
		return nil, err
	}
	object := make(map[string]json.RawMessage)
	if t.union != nil {
		if err = json.Unmarshal(b, &object); err != nil {
			return nil, err
		}
	}
	// set writes one encoded field, re-allocating the map if a null union cleared it.
	set := func(key string, raw json.RawMessage) {
		if object == nil {
			object = make(map[string]json.RawMessage)
		}
		object[key] = raw
	}
	if t.AiEnabled != nil {
		raw, err := json.Marshal(t.AiEnabled)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_enabled': %w", err)
		}
		set("ai_enabled", raw)
	}
	if t.AiModel != nil {
		raw, err := json.Marshal(t.AiModel)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_model': %w", err)
		}
		set("ai_model", raw)
	}
	if t.AiProvider != nil {
		raw, err := json.Marshal(t.AiProvider)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_provider': %w", err)
		}
		set("ai_provider", raw)
	}
	if t.ApiKey != nil {
		raw, err := json.Marshal(t.ApiKey)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'api_key': %w", err)
		}
		set("api_key", raw)
	}
	if t.CurrentLevel != nil {
		raw, err := json.Marshal(t.CurrentLevel)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'current_level': %w", err)
		}
		set("current_level", raw)
	}
	if t.Email != nil {
		raw, err := json.Marshal(t.Email)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'email': %w", err)
		}
		set("email", raw)
	}
	if t.PreferredLanguage != nil {
		raw, err := json.Marshal(t.PreferredLanguage)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'preferred_language': %w", err)
		}
		set("preferred_language", raw)
	}
	if t.SelectedRoles != nil {
		raw, err := json.Marshal(t.SelectedRoles)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'selectedRoles': %w", err)
		}
		set("selectedRoles", raw)
	}
	if t.Timezone != nil {
		raw, err := json.Marshal(t.Timezone)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'timezone': %w", err)
		}
		set("timezone", raw)
	}
	if t.Username != nil {
		raw, err := json.Marshal(t.Username)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'username': %w", err)
		}
		set("username", raw)
	}
	return json.Marshal(object)
}
// UnmarshalJSON loads b into the UserUpdateRequest.
//
// The payload is first handed to the union's own UnmarshalJSON, then decoded
// into a key/value map from which the typed fields (ai_enabled, ai_model,
// ai_provider, api_key, current_level, email, preferred_language,
// selectedRoles, timezone, username) are populated. Keys absent from the
// payload leave the corresponding field untouched. A failure to decode a
// typed field is wrapped with that field's JSON key.
func (t *UserUpdateRequest) UnmarshalJSON(b []byte) error {
	if err := t.union.UnmarshalJSON(b); err != nil {
		return err
	}
	fields := make(map[string]json.RawMessage)
	if err := json.Unmarshal(b, &fields); err != nil {
		return err
	}
	if value, ok := fields["ai_enabled"]; ok {
		if err := json.Unmarshal(value, &t.AiEnabled); err != nil {
			return fmt.Errorf("error reading 'ai_enabled': %w", err)
		}
	}
	if value, ok := fields["ai_model"]; ok {
		if err := json.Unmarshal(value, &t.AiModel); err != nil {
			return fmt.Errorf("error reading 'ai_model': %w", err)
		}
	}
	if value, ok := fields["ai_provider"]; ok {
		if err := json.Unmarshal(value, &t.AiProvider); err != nil {
			return fmt.Errorf("error reading 'ai_provider': %w", err)
		}
	}
	if value, ok := fields["api_key"]; ok {
		if err := json.Unmarshal(value, &t.ApiKey); err != nil {
			return fmt.Errorf("error reading 'api_key': %w", err)
		}
	}
	if value, ok := fields["current_level"]; ok {
		if err := json.Unmarshal(value, &t.CurrentLevel); err != nil {
			return fmt.Errorf("error reading 'current_level': %w", err)
		}
	}
	if value, ok := fields["email"]; ok {
		if err := json.Unmarshal(value, &t.Email); err != nil {
			return fmt.Errorf("error reading 'email': %w", err)
		}
	}
	if value, ok := fields["preferred_language"]; ok {
		if err := json.Unmarshal(value, &t.PreferredLanguage); err != nil {
			return fmt.Errorf("error reading 'preferred_language': %w", err)
		}
	}
	if value, ok := fields["selectedRoles"]; ok {
		if err := json.Unmarshal(value, &t.SelectedRoles); err != nil {
			return fmt.Errorf("error reading 'selectedRoles': %w", err)
		}
	}
	if value, ok := fields["timezone"]; ok {
		if err := json.Unmarshal(value, &t.Timezone); err != nil {
			return fmt.Errorf("error reading 'timezone': %w", err)
		}
	}
	if value, ok := fields["username"]; ok {
		if err := json.Unmarshal(value, &t.Username); err != nil {
			return fmt.Errorf("error reading 'username': %w", err)
		}
	}
	return nil
}
// Package config handles application configuration loading from environment variables.
package config
import (
"fmt"
"os"
"reflect"
"sort"
"strconv"
"strings"
"time"
contextutils "quizapp/internal/utils"
"gopkg.in/yaml.v3"
)
// ProviderConfig defines the structure for a single provider.
// Every field uses the same key for JSON and YAML. The url and
// question_batch_size keys are tagged omitempty and are therefore left out of
// encoded output when empty/zero.
type ProviderConfig struct {
Name string `json:"name" yaml:"name"`
Code string `json:"code" yaml:"code"`
URL string `json:"url,omitempty" yaml:"url,omitempty"`
SupportsGrammar bool `json:"supports_grammar" yaml:"supports_grammar"`
UsageSupported bool `json:"usage_supported" yaml:"usage_supported"`
QuestionBatchSize int `json:"question_batch_size,omitempty" yaml:"question_batch_size,omitempty"`
// Models lists the AIModel entries configured for this provider.
Models []AIModel `json:"models" yaml:"models"`
}
// AIModel represents an AI model configuration.
// It is the element type of ProviderConfig.Models. The max_tokens key is
// tagged omitempty and is left out of encoded output when zero.
type AIModel struct {
Name string `json:"name" yaml:"name"`
Code string `json:"code" yaml:"code"`
MaxTokens int `json:"max_tokens,omitempty" yaml:"max_tokens,omitempty"`
}
// QuestionVarietyConfig defines the variety configuration for question generation.
// All fields are lists of strings except GrammarFocusByLevel, which maps a
// string key to a list of strings (presumably keyed by level name, as the
// field name suggests — confirm against the configuration file).
type QuestionVarietyConfig struct {
TopicCategories []string `json:"topic_categories" yaml:"topic_categories"`
GrammarFocusByLevel map[string][]string `json:"grammar_focus_by_level" yaml:"grammar_focus_by_level"`
GrammarFocus []string `json:"grammar_focus" yaml:"grammar_focus"`
VocabularyDomains []string `json:"vocabulary_domains" yaml:"vocabulary_domains"`
Scenarios []string `json:"scenarios" yaml:"scenarios"`
StyleModifiers []string `json:"style_modifiers" yaml:"style_modifiers"`
DifficultyModifiers []string `json:"difficulty_modifiers" yaml:"difficulty_modifiers"`
TimeContexts []string `json:"time_contexts" yaml:"time_contexts"`
}
// LanguageLevelConfig represents the levels and descriptions for a specific language.
// It is the value type of Config.LanguageLevels.
type LanguageLevelConfig struct {
// Code is the language code; the Config lookup helpers accept it as an
// alternative to the language name.
Code string `json:"code" yaml:"code"`
// TtsLocale and TtsVoice are copied into LanguageInfo by
// GetLanguageInfoList only when non-empty.
TtsLocale string `json:"tts_locale" yaml:"tts_locale"`
TtsVoice string `json:"tts_voice" yaml:"tts_voice"`
// Levels lists the level names available for the language.
Levels []string `json:"levels" yaml:"levels"`
// Descriptions maps a level name to its description.
Descriptions map[string]string `json:"descriptions" yaml:"descriptions"`
}
// LanguageInfo represents a language with its code and human-readable name.
// It carries JSON tags only (no YAML tags). GetLanguageInfoList fills Name
// from the LanguageLevels map key and Code from the matching
// LanguageLevelConfig.
type LanguageInfo struct {
Code string `json:"code"`
Name string `json:"name"`
// TtsLocale and TtsVoice are optional: a nil pointer is omitted from the
// encoded JSON.
TtsLocale *string `json:"tts_locale,omitempty"`
TtsVoice *string `json:"tts_voice,omitempty"`
}
// AuthConfig represents authentication-related configuration.
// The allowed_domains and allowed_emails keys are tagged omitempty and are
// left out of encoded output when the lists are empty.
type AuthConfig struct {
SignupsDisabled bool `json:"signups_disabled" yaml:"signups_disabled"`
AllowedDomains []string `json:"allowed_domains,omitempty" yaml:"allowed_domains,omitempty"`
AllowedEmails []string `json:"allowed_emails,omitempty" yaml:"allowed_emails,omitempty"`
}
// SystemConfig represents system-wide configuration.
// It currently holds only the authentication settings under the "auth" key.
type SystemConfig struct {
Auth AuthConfig `json:"auth" yaml:"auth"`
}
// Config holds all configuration for the application.
// Every field uses the same key for JSON and YAML.
type Config struct {
// Server configuration
Server ServerConfig `json:"server" yaml:"server"`
// Database configuration
Database DatabaseConfig `json:"database" yaml:"database"`
// AI Providers and Language Levels.
// LanguageLevels is keyed by language name: the accessor methods on Config
// use the map key as the language name and LanguageLevelConfig.Code as the
// language code.
Providers []ProviderConfig `json:"providers" yaml:"providers"`
LanguageLevels map[string]LanguageLevelConfig `json:"language_levels" yaml:"language_levels"`
// Variety and System are optional pointers; nil values are omitted from
// encoded output.
Variety *QuestionVarietyConfig `json:"variety,omitempty" yaml:"variety,omitempty"`
System *SystemConfig `json:"system,omitempty" yaml:"system,omitempty"`
// OAuth Configuration
GoogleOAuthClientID string `json:"google_oauth_client_id" yaml:"google_oauth_client_id"`
// NOTE(review): the client secret has regular JSON/YAML tags, so it is
// included whenever this struct is encoded.
GoogleOAuthClientSecret string `json:"google_oauth_client_secret" yaml:"google_oauth_client_secret"`
GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url" yaml:"google_oauth_redirect_url"`
// OpenTelemetry Configuration
OpenTelemetry OpenTelemetryConfig `json:"open_telemetry" yaml:"open_telemetry"`
// Email Configuration
Email EmailConfig `json:"email" yaml:"email"`
// Story Configuration
Story StoryConfig `json:"story" yaml:"story"`
// Translation Configuration
Translation TranslationConfig `json:"translation" yaml:"translation"`
// Linear Configuration
Linear LinearConfig `json:"linear" yaml:"linear"`
// Internal fields
IsTest bool `json:"is_test" yaml:"is_test"`
}
// ServerConfig represents server configuration.
// Every field uses the same snake_case key for JSON and YAML.
type ServerConfig struct {
// Port and WorkerPort are stored as strings, not integers.
Port string `json:"port" yaml:"port"`
WorkerPort string `json:"worker_port" yaml:"worker_port"`
AdminUsername string `json:"admin_username" yaml:"admin_username"`
// NOTE(review): AdminPassword and SessionSecret have regular JSON/YAML
// tags, so they are included whenever this struct is encoded.
AdminPassword string `json:"admin_password" yaml:"admin_password"`
SessionSecret string `json:"session_secret" yaml:"session_secret"`
Debug bool `json:"debug" yaml:"debug"`
LogLevel string `json:"log_level" yaml:"log_level"`
WorkerBaseURL string `json:"worker_base_url" yaml:"worker_base_url"`
WorkerInternalURL string `json:"worker_internal_url" yaml:"worker_internal_url"`
BackendBaseURL string `json:"backend_base_url" yaml:"backend_base_url"`
AppBaseURL string `json:"app_base_url" yaml:"app_base_url"`
MaxAIConcurrent int `json:"max_ai_concurrent" yaml:"max_ai_concurrent"`
MaxAIPerUser int `json:"max_ai_per_user" yaml:"max_ai_per_user"`
CORSOrigins []string `json:"cors_origins" yaml:"cors_origins"`
QuestionRefillThreshold int `json:"question_refill_threshold" yaml:"question_refill_threshold"`
// DailyFreshQuestionRatio controls the minimum fraction of fresh (never-seen)
// questions to aim for when refilling question pools (0.0 - 1.0). Example: 0.35
// means at least 35% fresh questions when refilling.
DailyFreshQuestionRatio float64 `json:"daily_fresh_question_ratio" yaml:"daily_fresh_question_ratio"`
MaxHistory int `json:"max_history" yaml:"max_history"`
MaxActivityLogs int `json:"max_activity_logs" yaml:"max_activity_logs"`
DailyRepeatAvoidDays int `json:"daily_repeat_avoid_days" yaml:"daily_repeat_avoid_days"`
// DailyHorizonDays controls how many days ahead the worker will assign
// daily questions (e.g. 0 = today only, 1 = today+1, ...). If unset or
// <= 0 the worker will fall back to the DAILY_HORIZON_DAYS environment
// variable (default 1).
DailyHorizonDays int `json:"daily_horizon_days" yaml:"daily_horizon_days"`
}
// GetLanguages lists every supported language name, i.e. the keys of
// LanguageLevels, in ascending lexical order. The result is never nil; it is
// empty when no languages are configured.
func (c *Config) GetLanguages() []string {
	count := len(c.LanguageLevels)
	if count == 0 {
		return []string{}
	}
	names := make([]string, 0, count)
	for name := range c.LanguageLevels {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}
// GetLanguageInfoList returns one LanguageInfo per configured language, sorted
// by name for consistent ordering.
//
// Name is the LanguageLevels map key and Code comes from the language's
// LanguageLevelConfig. TtsLocale and TtsVoice are set only when the configured
// value is non-empty; otherwise they stay nil.
//
// The returned pointers reference per-entry copies of the configured strings
// rather than fields of the range variable. With pre-Go-1.22 loop semantics
// the range variable is shared across iterations, so pointing into it would
// make every entry report the values of the last language visited.
//
// The result is never nil; it is empty when LanguageLevels is nil.
func (c *Config) GetLanguageInfoList() []LanguageInfo {
	if c.LanguageLevels == nil {
		return []LanguageInfo{}
	}
	languageInfos := make([]LanguageInfo, 0, len(c.LanguageLevels))
	for langName, langConfig := range c.LanguageLevels {
		info := LanguageInfo{
			Code: langConfig.Code,
			Name: langName,
		}
		if langConfig.TtsLocale != "" {
			locale := langConfig.TtsLocale
			info.TtsLocale = &locale
		}
		if langConfig.TtsVoice != "" {
			voice := langConfig.TtsVoice
			info.TtsVoice = &voice
		}
		languageInfos = append(languageInfos, info)
	}
	// Sort by name for consistent ordering
	sort.Slice(languageInfos, func(i, j int) bool {
		return languageInfos[i].Name < languageInfos[j].Name
	})
	return languageInfos
}
// GetLevelsForLanguage returns the levels configured for the given language.
// The argument is matched against the language name (the LanguageLevels key)
// first and against each entry's Code second. The stored Levels slice of the
// matching entry is returned as-is; an empty, non-nil slice is returned when
// nothing matches.
func (c *Config) GetLevelsForLanguage(language string) []string {
	// Exact match on the language name wins.
	if entry, ok := c.LanguageLevels[language]; ok {
		return entry.Levels
	}
	// Otherwise treat the argument as a language code.
	for name := range c.LanguageLevels {
		if entry := c.LanguageLevels[name]; entry.Code == language {
			return entry.Levels
		}
	}
	return []string{}
}
// GetLevelDescriptionsForLanguage resolves language first as a LanguageLevels
// key and, failing that, as a language Code, and returns the Descriptions map
// of the match. An empty, non-nil map is returned when nothing matches or
// LanguageLevels is nil.
func (c *Config) GetLevelDescriptionsForLanguage(language string) map[string]string {
	if c.LanguageLevels != nil {
		// Exact key (language name) wins over a code match.
		if byName, found := c.LanguageLevels[language]; found {
			return byName.Descriptions
		}
		for _, candidate := range c.LanguageLevels {
			if candidate.Code == language {
				return candidate.Descriptions
			}
		}
	}
	return map[string]string{}
}
// GetAllLevels collects the Levels of every configured language, removes
// duplicates and returns them in ascending lexical order. The result is
// always non-nil.
func (c *Config) GetAllLevels() []string {
	unique := make([]string, 0)
	if c.LanguageLevels == nil {
		return unique
	}
	seen := make(map[string]struct{})
	for _, langConfig := range c.LanguageLevels {
		for _, level := range langConfig.Levels {
			if _, dup := seen[level]; dup {
				continue
			}
			seen[level] = struct{}{}
			unique = append(unique, level)
		}
	}
	sort.Strings(unique)
	return unique
}
// GetAllLevelDescriptions merges the Descriptions maps of every configured
// language into one level -> description map. When two languages describe
// the same level, whichever is visited last wins; map iteration order is
// unspecified, so that choice is not deterministic. The result is always
// non-nil.
func (c *Config) GetAllLevelDescriptions() map[string]string {
	merged := make(map[string]string)
	if c.LanguageLevels == nil {
		return merged
	}
	for _, langConfig := range c.LanguageLevels {
		for level, text := range langConfig.Descriptions {
			merged[level] = text
		}
	}
	return merged
}
// Languages is a shorthand for GetLanguages and returns exactly what it
// returns.
func (c *Config) Languages() []string {
	languages := c.GetLanguages()
	return languages
}
// Levels is a shorthand for GetAllLevels and returns exactly what it returns.
func (c *Config) Levels() []string {
	levels := c.GetAllLevels()
	return levels
}
// LevelDescriptions is a shorthand for GetAllLevelDescriptions and returns
// exactly what it returns.
func (c *Config) LevelDescriptions() map[string]string {
	descriptions := c.GetAllLevelDescriptions()
	return descriptions
}
// IsSignupDisabled reports the System.Auth.SignupsDisabled setting. With no
// System section configured it reports false, i.e. signups stay enabled.
func (c *Config) IsSignupDisabled() bool {
	return c.System != nil && c.System.Auth.SignupsDisabled
}
// IsEmailAllowed reports whether email appears in System.Auth.AllowedEmails.
// Both sides are trimmed of surrounding whitespace and lower-cased before
// comparison. It reports false when no System section or no list is
// configured.
func (c *Config) IsEmailAllowed(email string) bool {
	if c.System == nil {
		return false
	}
	wanted := strings.ToLower(strings.TrimSpace(email))
	// Ranging over a nil list is a no-op, so an unset list falls through to
	// the final false.
	for _, entry := range c.System.Auth.AllowedEmails {
		if candidate := strings.ToLower(strings.TrimSpace(entry)); candidate == wanted {
			return true
		}
	}
	return false
}
// IsDomainAllowed reports whether domain appears in
// System.Auth.AllowedDomains. Both sides are trimmed of surrounding
// whitespace and lower-cased before comparison. It reports false when no
// System section or no list is configured.
func (c *Config) IsDomainAllowed(domain string) bool {
	if c.System == nil {
		return false
	}
	wanted := strings.ToLower(strings.TrimSpace(domain))
	// Ranging over a nil list is a no-op, so an unset list falls through to
	// the final false.
	for _, entry := range c.System.Auth.AllowedDomains {
		if candidate := strings.ToLower(strings.TrimSpace(entry)); candidate == wanted {
			return true
		}
	}
	return false
}
// IsOAuthSignupAllowed decides whether an OAuth signup may proceed for email.
//
//   - With no System section it returns false.
//   - When signups are not disabled it returns true for any email.
//   - When signups are disabled, the trimmed, lower-cased email must pass
//     contextutils.IsValidEmail and either be listed in AllowedEmails or have
//     its domain listed in AllowedDomains.
//
// NOTE(review): IsSignupDisabled treats a missing System section as "signups
// enabled", whereas this method rejects in that case — confirm the asymmetry
// is intended.
func (c *Config) IsOAuthSignupAllowed(email string) bool {
	if c.System == nil {
		return false
	}
	// If signups are not disabled, OAuth signup is always allowed.
	if !c.System.Auth.SignupsDisabled {
		return true
	}
	// Signups are disabled: only whitelisted addresses get through.
	normalizedEmail := strings.ToLower(strings.TrimSpace(email))
	if !contextutils.IsValidEmail(normalizedEmail) {
		return false
	}
	if c.IsEmailAllowed(normalizedEmail) {
		return true
	}
	// The domain is whatever follows the LAST '@'. Taking the second element
	// of a split on '@' would pick the wrong segment for an address with
	// more than one '@', and would index out of range if the validator ever
	// let an address without '@' through.
	at := strings.LastIndex(normalizedEmail, "@")
	if at < 0 || at == len(normalizedEmail)-1 {
		return false
	}
	return c.IsDomainAllowed(normalizedEmail[at+1:])
}
// OpenTelemetryConfig holds all OpenTelemetry-related configuration.
//
// The "Default:" notes on the fields describe intended defaults.
// NOTE(review): nothing in this struct applies them — confirm where the
// defaults are actually set.
//
// The environment-override walker in this file only handles maps whose
// values are structs, so Headers (map[string]string) cannot be overridden
// from the environment.
type OpenTelemetryConfig struct {
Endpoint string `json:"endpoint" yaml:"endpoint"` // Default: "http://localhost:4317"
Protocol string `json:"protocol" yaml:"protocol"` // "grpc" or "http", default: "grpc"
Insecure bool `json:"insecure" yaml:"insecure"` // Default: true (for localhost)
Headers map[string]string `json:"headers" yaml:"headers"` // For authenticated endpoints
ServiceName string `json:"service_name" yaml:"service_name"` // Default: "quiz-backend" or "quiz-worker"
ServiceVersion string `json:"service_version" yaml:"service_version"` // From version package
EnableTracing bool `json:"enable_tracing" yaml:"enable_tracing"` // Default: true
EnableMetrics bool `json:"enable_metrics" yaml:"enable_metrics"` // Default: true
EnableLogging bool `json:"enable_logging" yaml:"enable_logging"` // Default: true (future)
SamplingRate float64 `json:"sampling_rate" yaml:"sampling_rate"` // Default: 1.0 (100%)
}
// DatabaseConfig represents database configuration: the connection URL and
// the connection-pool limits.
//
// ConnMaxLifetime is a time.Duration, i.e. an int64 kind. The
// environment-override walker in this file parses integer kinds with
// strconv.ParseInt, so an override for it must be a plain integer number of
// nanoseconds; a value such as "5m" fails to parse and is ignored.
type DatabaseConfig struct {
URL string `json:"url" yaml:"url"`
MaxOpenConns int `json:"max_open_conns" yaml:"max_open_conns"` // Maximum number of open connections to the database
MaxIdleConns int `json:"max_idle_conns" yaml:"max_idle_conns"` // Maximum number of idle connections in the pool
ConnMaxLifetime time.Duration `json:"conn_max_lifetime" yaml:"conn_max_lifetime"` // Maximum amount of time a connection may be reused
}
// EmailConfig represents email/SMTP configuration. It groups the SMTP server
// settings, the daily reminder settings and an Enabled flag.
// NOTE(review): presumably Enabled switches all outgoing email on or off —
// confirm against the email service.
type EmailConfig struct {
SMTP SMTPConfig `json:"smtp" yaml:"smtp"`
DailyReminder DailyReminderConfig `json:"daily_reminder" yaml:"daily_reminder"`
Enabled bool `json:"enabled" yaml:"enabled"`
}
// SMTPConfig represents SMTP server configuration: where to connect (Host,
// Port), the credentials (Username, Password) and the sender identity
// (FromAddress, FromName).
//
// Password is a secret; take care not to log or serialise this struct
// verbatim.
type SMTPConfig struct {
Host string `json:"host" yaml:"host"`
Port int `json:"port" yaml:"port"`
Username string `json:"username" yaml:"username"`
Password string `json:"password" yaml:"password"`
FromAddress string `json:"from_address" yaml:"from_address"`
FromName string `json:"from_name" yaml:"from_name"`
}
// DailyReminderConfig represents daily reminder email configuration.
// NOTE(review): the time zone in which Hour is interpreted is not visible
// here — confirm against the code that schedules the reminder.
type DailyReminderConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Hour int `json:"hour" yaml:"hour"` // Hour of day to send (0-23)
}
// StorySectionLengthsConfig represents section length configuration by
// proficiency level. There is one string-keyed integer map per level, plus a
// three-level nested Overrides map.
// NOTE(review): the meaning of the map keys, the unit of the integer values
// and the nesting order of Overrides are not visible here — confirm against
// the story generation code before relying on them.
type StorySectionLengthsConfig struct {
Beginner map[string]int `json:"beginner" yaml:"beginner"`
Elementary map[string]int `json:"elementary" yaml:"elementary"`
Intermediate map[string]int `json:"intermediate" yaml:"intermediate"`
UpperIntermediate map[string]int `json:"upper_intermediate" yaml:"upper_intermediate"`
Advanced map[string]int `json:"advanced" yaml:"advanced"`
Proficient map[string]int `json:"proficient" yaml:"proficient"`
Overrides map[string]map[string]map[string]int `json:"overrides" yaml:"overrides"`
}
// StoryConfig represents story mode configuration: feature switches
// (GenerationEnabled, EngagementBasedGeneration), per-user and per-day
// limits, the number of questions per section and the section length table.
// NOTE(review): how each limit is enforced is not visible here — confirm
// against the story service and worker.
type StoryConfig struct {
MaxArchivedPerUser int `json:"max_archived_per_user" yaml:"max_archived_per_user"`
GenerationEnabled bool `json:"generation_enabled" yaml:"generation_enabled"`
EngagementBasedGeneration bool `json:"engagement_based_generation" yaml:"engagement_based_generation"`
SectionLengths StorySectionLengthsConfig `json:"section_lengths" yaml:"section_lengths"`
QuestionsPerSection int `json:"questions_per_section" yaml:"questions_per_section"`
MaxExtraGenerationsPerDay int `json:"max_extra_generations_per_day" yaml:"max_extra_generations_per_day"`
MaxWorkerGenerationsPerDay int `json:"max_worker_generations_per_day" yaml:"max_worker_generations_per_day"`
}
// TranslationConfig represents translation service configuration.
//
// Providers is a map with string keys and struct values, which is the shape
// the environment-override walker in this file handles: a field of an
// existing entry can be overridden by a variable named
// <prefix>_<UPPERCASED KEY>_<UPPERCASED FIELD TAG>, where <prefix> is built
// from the yaml tags of the enclosing fields. Only keys already present in
// the map are considered.
type TranslationConfig struct {
Enabled bool `json:"enabled" yaml:"enabled"`
DefaultProvider string `json:"default_provider" yaml:"default_provider"`
Providers map[string]TranslationProviderConfig `json:"providers" yaml:"providers"`
Quota TranslationQuotaConfig `json:"quota" yaml:"quota"`
}
// TranslationProviderConfig represents a translation provider configuration.
//
// APIKey is a secret; take care not to log or serialise this struct
// verbatim.
// NOTE(review): how BaseURL and APIEndpoint are combined, and the unit of
// MaxTextLength, are not visible here — confirm against the translation
// service.
type TranslationProviderConfig struct {
Name string `json:"name" yaml:"name"`
Code string `json:"code" yaml:"code"`
APIKey string `json:"api_key" yaml:"api_key"`
BaseURL string `json:"base_url" yaml:"base_url"`
APIEndpoint string `json:"api_endpoint" yaml:"api_endpoint"`
MaxTextLength int `json:"max_text_length" yaml:"max_text_length"`
}
// TranslationQuotaConfig represents quota configuration for translation
// services. Quotas are expressed in characters per month.
type TranslationQuotaConfig struct {
// Enabled turns quota handling on or off.
// NOTE(review): presumably no limit is enforced when false — confirm.
Enabled bool `json:"enabled" yaml:"enabled"`
// Monthly character quotas per provider
GoogleMonthlyQuota int64 `json:"google_monthly_quota" yaml:"google_monthly_quota"`
// Default monthly quota for new providers (in characters)
DefaultMonthlyQuota int64 `json:"default_monthly_quota" yaml:"default_monthly_quota"`
}
// LinearConfig represents Linear integration configuration.
//
// APIKey is a secret; take care not to log or serialise this struct
// verbatim.
// NOTE(review): the LINEAR_* variable names in the field comments match the
// environment-override naming scheme only if this struct is reached through
// a field tagged "linear" at the top level of Config — confirm.
type LinearConfig struct {
APIKey string `json:"api_key" yaml:"api_key"` // API key from LINEAR_API_KEY env var
TeamID string `json:"team_id" yaml:"team_id"` // Team ID, override via LINEAR_TEAM_ID
ProjectID string `json:"project_id" yaml:"project_id"` // Project ID, override via LINEAR_PROJECT_ID
DefaultLabels []string `json:"default_labels" yaml:"default_labels"` // Optional default labels
DefaultState string `json:"default_state" yaml:"default_state"` // Optional default state (e.g., "Todo")
Enabled bool `json:"enabled" yaml:"enabled"` // Feature flag
}
// NewConfig builds the application configuration in two passes: the YAML
// file is loaded first (see loadConfigWithOverrides), then environment
// variables are applied on top of it. A load failure is returned wrapped in
// contextutils.ErrInternalError.
func NewConfig() (result0 *Config, err error) {
	loaded, loadErr := loadConfigWithOverrides()
	if loadErr != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to load config: %v", loadErr)
	}
	// Environment variables take precedence over values from the file.
	loaded.overrideFromEnv()
	return loaded, nil
}
// overrideFromEnv walks c with reflection and replaces field values with the
// matching environment variables, starting with an empty name prefix.
func (c *Config) overrideFromEnv() {
	overrideStructFromEnvWithPrefix(c, "")
}
// overrideStructFromEnv applies environment variable overrides to the struct
// (or pointer to struct) v, using no prefix for the top-level field names.
func overrideStructFromEnv(v interface{}) {
	const noPrefix = ""
	overrideStructFromEnvWithPrefix(v, noPrefix)
}
// overrideStructFromEnvWithPrefix recursively overrides struct fields with environment variables
func overrideStructFromEnvWithPrefix(v interface{}, prefix string) {
val := reflect.ValueOf(v)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
if val.Kind() != reflect.Struct {
return
}
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
fieldType := typ.Field(i)
// Skip unexported fields
if !field.CanSet() {
continue
}
// Get the yaml tag for the field
yamlTag := fieldType.Tag.Get("yaml")
if yamlTag == "" || yamlTag == "-" {
continue
}
// Convert yaml tag to environment variable name
envKey := strings.ToUpper(strings.ReplaceAll(yamlTag, "-", "_"))
if prefix != "" {
envKey = prefix + "_" + envKey
}
switch field.Kind() {
case reflect.String:
if envVal := os.Getenv(envKey); envVal != "" {
field.SetString(envVal)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if envVal := os.Getenv(envKey); envVal != "" {
if intVal, err := strconv.ParseInt(envVal, 10, 64); err == nil {
field.SetInt(intVal)
}
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
if envVal := os.Getenv(envKey); envVal != "" {
if uintVal, err := strconv.ParseUint(envVal, 10, 64); err == nil {
field.SetUint(uintVal)
}
}
case reflect.Float32, reflect.Float64:
if envVal := os.Getenv(envKey); envVal != "" {
if floatVal, err := strconv.ParseFloat(envVal, 64); err == nil {
field.SetFloat(floatVal)
}
}
case reflect.Bool:
if envVal := os.Getenv(envKey); envVal != "" {
if boolVal, err := strconv.ParseBool(envVal); err == nil {
field.SetBool(boolVal)
}
}
case reflect.Slice:
if envVal := os.Getenv(envKey); envVal != "" {
// Handle string slices (like CORS_ORIGINS)
if field.Type().Elem().Kind() == reflect.String {
slice := strings.Split(envVal, ",")
field.Set(reflect.ValueOf(slice))
}
}
case reflect.Map:
// Handle map fields with string keys and struct values
if field.Type().Key().Kind() == reflect.String && field.Type().Elem().Kind() == reflect.Struct {
handleMapFieldOverrides(field, yamlTag, prefix)
}
case reflect.Struct:
// Recursively process nested structs with the field name as prefix
if field.CanAddr() {
fieldPrefix := strings.ToUpper(strings.ReplaceAll(yamlTag, "-", "_"))
if prefix != "" {
fieldPrefix = prefix + "_" + fieldPrefix
}
overrideStructFromEnvWithPrefix(field.Addr().Interface(), fieldPrefix)
}
case reflect.Ptr:
// Handle pointer to struct
if !field.IsNil() && field.Elem().Kind() == reflect.Struct {
fieldPrefix := strings.ToUpper(strings.ReplaceAll(yamlTag, "-", "_"))
if prefix != "" {
fieldPrefix = prefix + "_" + fieldPrefix
}
overrideStructFromEnvWithPrefix(field.Interface(), fieldPrefix)
}
}
}
}
// handleMapFieldOverrides applies environment variable overrides to the
// entries of a map field that has string keys and struct values.
//
// The variable prefix is the name part of yamlTag (options after a comma are
// ignored), upper-cased with "-" replaced by "_", preceded by parentPrefix
// and "_" when parentPrefix is non-empty. Each existing entry is passed to
// createStructWithOverrides and replaced only when that returns a valid
// value. Keys that are not already present in the map are never created.
//
// The call is a no-op when field is not settable, is not a map, or does not
// have string keys.
func handleMapFieldOverrides(field reflect.Value, yamlTag, parentPrefix string) {
	// Check the kind before asking for the key type: Type().Key() panics for
	// anything that is not a map.
	if !field.CanSet() || field.Kind() != reflect.Map || field.Type().Key().Kind() != reflect.String {
		return
	}
	// Tolerate a raw tag such as "providers,omitempty".
	if comma := strings.Index(yamlTag, ","); comma >= 0 {
		yamlTag = yamlTag[:comma]
	}
	mapPrefix := strings.ToUpper(strings.ReplaceAll(yamlTag, "-", "_"))
	if parentPrefix != "" {
		mapPrefix = parentPrefix + "_" + mapPrefix
	}
	// MapKeys returns a snapshot, so replacing entries while looping is safe.
	for _, key := range field.MapKeys() {
		entry := field.MapIndex(key)
		if !entry.IsValid() || entry.Kind() != reflect.Struct {
			continue
		}
		// Map values are not addressable, so an overridden entry has to be
		// rebuilt and stored back under the same key.
		if replacement := createStructWithOverrides(entry, key.String(), mapPrefix); replacement.IsValid() {
			field.SetMapIndex(key, replacement)
		}
	}
}
// createStructWithOverrides returns a copy of originalStruct with environment
// variable overrides applied, or the zero reflect.Value when originalStruct
// is not a struct or no override variable is set.
//
// For every exported field with a yaml tag name, the variable
// <mapPrefix>_<UPPERCASED keyName>_<UPPERCASED TAG NAME> is consulted ("-" in
// the tag becomes "_"; tag options after a comma are ignored). A non-empty
// value is handed to setReflectValue.
//
// The copy starts as a full copy of the original, so unexported fields are
// preserved, and a value that cannot be parsed (or targets a kind
// setReflectValue does not handle) leaves the original field value in place
// instead of resetting it to zero.
func createStructWithOverrides(originalStruct reflect.Value, keyName, mapPrefix string) reflect.Value {
	if !originalStruct.IsValid() || originalStruct.Kind() != reflect.Struct {
		return reflect.Value{}
	}
	structType := originalStruct.Type()
	newStruct := reflect.New(structType).Elem()
	// Whole-struct copy first; overrides are then layered on top.
	newStruct.Set(originalStruct)
	updated := false
	for i := 0; i < structType.NumField(); i++ {
		newField := newStruct.Field(i)
		// Skip unexported fields (already copied above).
		if !newField.CanSet() {
			continue
		}
		tagName := structType.Field(i).Tag.Get("yaml")
		if comma := strings.Index(tagName, ","); comma >= 0 {
			tagName = tagName[:comma]
		}
		if tagName == "" || tagName == "-" {
			continue
		}
		envKey := strings.ToUpper(strings.ReplaceAll(tagName, "-", "_"))
		envVarName := fmt.Sprintf("%s_%s_%s", mapPrefix, strings.ToUpper(keyName), envKey)
		if envVal := os.Getenv(envVarName); envVal != "" {
			setReflectValue(newField, envVal)
			updated = true
		}
	}
	if updated {
		return newStruct
	}
	return reflect.Value{}
}
// setReflectValue parses envVal according to the kind of field and stores the
// result in field.
//
// Supported kinds are string, the signed and unsigned integers, the floats
// and bool. Integers are parsed in base 10. Parsing uses the bit size of the
// field's own type, so a value that does not fit (for example "300" for an
// int8) is rejected rather than silently wrapped. A value that fails to
// parse, a field of any other kind, and a field that is not settable are all
// left untouched.
func setReflectValue(field reflect.Value, envVal string) {
	if !field.CanSet() {
		return
	}
	switch field.Kind() {
	case reflect.String:
		field.SetString(envVal)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if parsed, err := strconv.ParseInt(envVal, 10, field.Type().Bits()); err == nil {
			field.SetInt(parsed)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if parsed, err := strconv.ParseUint(envVal, 10, field.Type().Bits()); err == nil {
			field.SetUint(parsed)
		}
	case reflect.Float32, reflect.Float64:
		if parsed, err := strconv.ParseFloat(envVal, field.Type().Bits()); err == nil {
			field.SetFloat(parsed)
		}
	case reflect.Bool:
		if parsed, err := strconv.ParseBool(envVal); err == nil {
			field.SetBool(parsed)
		}
	}
}
// loadConfigWithOverrides picks the configuration file to read. When the
// QUIZ_CONFIG_FILE environment variable is non-empty that path is loaded and
// a failure is returned wrapped in contextutils.ErrInternalError; otherwise
// "config.yaml" is loaded and its error, if any, is returned as-is.
func loadConfigWithOverrides() (result0 *Config, err error) {
	customPath := os.Getenv("QUIZ_CONFIG_FILE")
	if customPath == "" {
		// No explicit path: fall back to the default file name.
		return loadConfigFromFile("config.yaml")
	}
	loaded, loadErr := loadConfigFromFile(customPath)
	if loadErr != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to load config from %s: %v", customPath, loadErr)
	}
	return loaded, nil
}
// loadConfigFromFile reads the file at path and decodes it as YAML into a new
// Config. Read and decode errors are returned unwrapped.
func loadConfigFromFile(path string) (result0 *Config, err error) {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, readErr
	}
	parsed := new(Config)
	if decodeErr := yaml.Unmarshal(raw, parsed); decodeErr != nil {
		return nil, decodeErr
	}
	return parsed, nil
}
// Package database provides database connection and migration functionality.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"quizapp/internal/config"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
// Import PostgreSQL driver for database/sql
_ "github.com/lib/pq"
// Add golang-migrate imports
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres" // required for golang-migrate postgres driver
_ "github.com/golang-migrate/migrate/v4/source/file" // required for golang-migrate file source
// OpenTelemetry SQL instrumentation
"go.nhat.io/otelsql"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
)
// Manager handles database operations with proper logging: opening the
// connection pool, applying schema.sql and running golang-migrate
// migrations. Its only state is the logger used to report progress and
// failures.
type Manager struct {
// logger receives the Info/Error messages emitted by the Manager methods.
logger *observability.Logger
}
// Process-wide state for the instrumented SQL driver. InitDBWithoutMigrations
// registers the otelsql driver inside otelDriverOnce.Do and stores the
// outcome here, so later calls reuse the same driver name (or see the same
// registration error).
var (
// otelDriverNameCache is the driver name returned by otelsql.Register.
otelDriverNameCache string
// otelDriverOnce guards the one-time registration.
otelDriverOnce sync.Once
// otelDriverErr is the error returned by otelsql.Register, if any.
otelDriverErr error
)
// NewManager returns a Manager that reports through logger.
func NewManager(logger *observability.Logger) *Manager {
	manager := new(Manager)
	manager.logger = logger
	return manager
}
// ErrTableAlreadyExists is the sentinel for "table already exists" failures.
// isTableExistsError recognises it through errors.Is.
// NOTE(review): nothing in this file returns it; driver errors are matched by
// message text instead — confirm which callers produce it.
var ErrTableAlreadyExists = errors.New("table already exists")
// DefaultDatabaseConfig returns the default pool settings (25 open
// connections, 5 idle connections, config.DatabaseConnMaxLifetime). URL is
// taken from the TEST_DATABASE_URL environment variable and is empty when
// that variable is unset or empty.
func DefaultDatabaseConfig() config.DatabaseConfig {
	return config.DatabaseConfig{
		// An unset variable yields "", which is also the zero value of URL.
		URL:             os.Getenv("TEST_DATABASE_URL"),
		MaxOpenConns:    25,
		MaxIdleConns:    5,
		ConnMaxLifetime: config.DatabaseConnMaxLifetime,
	}
}
// InitDB opens a connection pool for databaseURL using the default pool
// settings from DefaultDatabaseConfig, runs the schema and migrations, and
// returns the pool. It is a convenience wrapper around InitDBWithConfig.
//
// The URL recorded on the trace span has its password masked; when the URL
// cannot be parsed a fixed placeholder is recorded instead, so credentials
// never reach telemetry.
func (dm *Manager) InitDB(databaseURL string) (result0 *sql.DB, err error) {
	safeURL := "[redacted]"
	if parsed, parseErr := url.Parse(databaseURL); parseErr == nil {
		safeURL = parsed.Redacted()
	}
	_, span := observability.TraceDatabaseFunction(context.Background(), "InitDB",
		attribute.String("db.url", safeURL),
		attribute.String("db.name", extractDatabaseName(databaseURL)),
		attribute.String("db.system", "postgresql"),
		attribute.Bool("migrations.enabled", true),
	)
	defer observability.FinishSpan(span, &err)
	dbConfig := DefaultDatabaseConfig()
	// The explicit argument always wins over TEST_DATABASE_URL.
	dbConfig.URL = databaseURL
	return dm.InitDBWithConfig(dbConfig)
}
// InitDBWithConfig opens a connection pool with the given settings (see
// InitDBWithoutMigrations), then applies the schema and migrations (see
// RunMigrations) and returns the pool.
//
// If migrations fail the pool is closed before the error is returned, so a
// failed initialisation does not leak open connections. The URL recorded on
// the trace span has its password masked; an unparsable URL is replaced by a
// fixed placeholder.
func (dm *Manager) InitDBWithConfig(config config.DatabaseConfig) (result0 *sql.DB, err error) {
	safeURL := "[redacted]"
	if parsed, parseErr := url.Parse(config.URL); parseErr == nil {
		safeURL = parsed.Redacted()
	}
	ctx, span := observability.TraceDatabaseFunction(context.Background(), "InitDBWithConfig",
		attribute.String("db.url", safeURL),
		attribute.String("db.name", extractDatabaseName(config.URL)),
		attribute.String("db.system", "postgresql"),
		attribute.Bool("migrations.enabled", true),
		attribute.Int("db.max_open_conns", config.MaxOpenConns),
		attribute.Int("db.max_idle_conns", config.MaxIdleConns),
		attribute.String("db.conn_max_lifetime", config.ConnMaxLifetime.String()),
	)
	defer observability.FinishSpan(span, &err)
	db, err := dm.InitDBWithoutMigrations(config)
	if err != nil {
		return nil, err
	}
	if migrateErr := dm.RunMigrations(db); migrateErr != nil {
		// The caller never receives db on this path, so release it here.
		if closeErr := db.Close(); closeErr != nil {
			dm.logger.Error(ctx, "Failed to close database connection after migration failure", closeErr)
		}
		return nil, migrateErr
	}
	return db, nil
}
// extractDatabaseName extracts the database name from a PostgreSQL connection
// URL such as postgres://user:pass@host:port/dbname?sslmode=disable.
//
// Resolution order:
//  1. If the URL parses and has a non-empty path, the path without its
//     leading slash is the name.
//  2. If the URL parses and has a host but no path, there is no database
//     segment to find, so the default is returned. The text fallback is
//     deliberately skipped here: it would report "host:port" as the name.
//  3. Otherwise (typically a parse failure) the text after the last "/" and
//     before any "?" is used, if non-empty.
//  4. Failing all that, the default "quiz_db" is returned.
func extractDatabaseName(databaseURL string) string {
	const defaultName = "quiz_db"

	if parsed, parseErr := url.Parse(databaseURL); parseErr == nil {
		if name := strings.TrimPrefix(parsed.Path, "/"); name != "" {
			return name
		}
		if parsed.Host != "" {
			return defaultName
		}
	}

	// Fallback for strings url.Parse rejects (or that carry no host).
	if slash := strings.LastIndex(databaseURL, "/"); slash != -1 {
		name := databaseURL[slash+1:]
		if query := strings.Index(name, "?"); query != -1 {
			name = name[:query]
		}
		if name != "" {
			return name
		}
	}
	return defaultName
}
// InitDBWithoutMigrations opens a connection pool for config.URL through the
// OpenTelemetry-instrumented postgres driver, applies the pool limits from
// config, verifies connectivity with Ping and returns the pool. No schema or
// migration work is done.
//
// The instrumented driver is registered once per process; the database name
// attached to it comes from the URL of the first call. If Ping fails the pool
// is closed before the error is returned. The URL recorded on the trace span
// has its password masked; an unparsable URL is replaced by a fixed
// placeholder so credentials never reach telemetry.
func (dm *Manager) InitDBWithoutMigrations(config config.DatabaseConfig) (result0 *sql.DB, err error) {
	safeURL := "[redacted]"
	if parsed, parseErr := url.Parse(config.URL); parseErr == nil {
		safeURL = parsed.Redacted()
	}
	ctx, span := observability.TraceDatabaseFunction(context.Background(), "InitDBWithoutMigrations",
		attribute.String("database.url", safeURL),
	)
	defer observability.FinishSpan(span, &err)
	// Register OpenTelemetry SQL driver once per process and reuse the name.
	otelDriverOnce.Do(func() {
		otelDriverNameCache, otelDriverErr = otelsql.Register("postgres",
			otelsql.WithDatabaseName(extractDatabaseName(config.URL)),
			otelsql.TraceQueryWithArgs(),
			otelsql.WithSystem(semconv.DBSystemPostgreSQL),
			otelsql.TraceRowsAffected(),
		)
	})
	if otelDriverErr != nil {
		return nil, contextutils.WrapError(otelDriverErr, "failed to register otelsql driver")
	}
	// Connect to database using the instrumented driver.
	db, err := sql.Open(otelDriverNameCache, config.URL)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to open database connection")
	}
	// Set connection pool settings.
	db.SetMaxOpenConns(config.MaxOpenConns)
	db.SetMaxIdleConns(config.MaxIdleConns)
	db.SetConnMaxLifetime(config.ConnMaxLifetime)
	// sql.Open does not connect by itself; Ping proves the database is reachable.
	if pingErr := db.Ping(); pingErr != nil {
		if closeErr := db.Close(); closeErr != nil {
			dm.logger.Error(ctx, "Failed to close database connection after ping failure", closeErr)
		}
		return nil, contextutils.WrapError(pingErr, "failed to ping database")
	}
	dm.logger.Info(ctx, "Database connection established without migrations", map[string]interface{}{
		"max_open_conns":    config.MaxOpenConns,
		"max_idle_conns":    config.MaxIdleConns,
		"conn_max_lifetime": config.ConnMaxLifetime,
	})
	return db, nil
}
// RunMigrations brings the database up to date in two steps: it first applies
// the statements of schema.sql through db (runApplicationSchema), then runs
// the golang-migrate migrations (runGolangMigrate). The first failing step
// aborts the run and its error is returned wrapped.
func (dm *Manager) RunMigrations(db *sql.DB) (err error) {
	_, span := observability.TraceDatabaseFunction(context.Background(), "RunMigrations",
		attribute.String("db.system", "postgresql"),
		attribute.String("migration.type", "application_schema"),
	)
	defer observability.FinishSpan(span, &err)

	dm.logger.Info(context.Background(), "Starting database migrations...")

	// Step 1: the main application schema.
	if schemaErr := dm.runApplicationSchema(db); schemaErr != nil {
		return contextutils.WrapError(schemaErr, "failed to run application schema")
	}
	dm.logger.Info(context.Background(), "Application schema applied successfully")

	// Step 2: versioned migrations.
	if migrateErr := dm.runGolangMigrate(); migrateErr != nil {
		return contextutils.WrapError(migrateErr, "failed to run golang-migrate migrations")
	}
	dm.logger.Info(context.Background(), "Database migrations completed successfully")
	return nil
}
// runGolangMigrate applies all pending golang-migrate migrations found in the
// migrations directory located by GetMigrationsPath.
//
// It fails hard when the directory cannot be located, does not exist or
// cannot be read. A directory without any "*.up.sql" file is not an error:
// the step is skipped. The database is addressed through the DATABASE_URL
// environment variable, falling back to TEST_DATABASE_URL; at least one of
// them must be set. migrate.ErrNoChange is treated as success.
//
// The database URL written to the log has its password masked; an unparsable
// URL is replaced by a fixed placeholder so credentials never reach the logs.
func (dm *Manager) runGolangMigrate() (err error) {
	migrationsPath, err := dm.GetMigrationsPath()
	if err != nil {
		dm.logger.Error(context.Background(), "Could not find migrations path", err)
		return err // HARD FAIL if migrations path is not set
	}
	_, span := observability.TraceDatabaseFunction(context.Background(), "runGolangMigrate",
		attribute.String("db.system", "postgresql"),
		attribute.String("migration.type", "golang_migrate"),
		attribute.String("migration.path", migrationsPath),
	)
	defer observability.FinishSpan(span, &err)
	if migrationsPath == "" {
		err = errors.New("no golang-migrate migrations directory found")
		dm.logger.Error(context.Background(), "No golang-migrate migrations directory found, hard fail!", err)
		return err // HARD FAIL
	}
	// Check that the migrations directory exists.
	if _, statErr := os.Stat(migrationsPath); os.IsNotExist(statErr) {
		dm.logger.Error(context.Background(), "Migrations directory does not exist", statErr)
		err = statErr // HARD FAIL if directory does not exist
		return err
	}
	files, err := os.ReadDir(migrationsPath)
	if err != nil {
		dm.logger.Error(context.Background(), "Could not read migrations directory", err)
		return err // HARD FAIL
	}
	// Count the .up.sql files; zero means there is nothing to apply.
	migrationFileCount := 0
	for _, file := range files {
		if !file.IsDir() && strings.HasSuffix(file.Name(), ".up.sql") {
			migrationFileCount++
		}
	}
	span.SetAttributes(attribute.Int("migration.files.count", migrationFileCount))
	if migrationFileCount == 0 {
		dm.logger.Info(context.Background(), fmt.Sprintf("No migration files found in %s. Skipping golang-migrate.", migrationsPath))
		return nil
	}
	dbURL := os.Getenv("DATABASE_URL")
	if dbURL == "" {
		dbURL = os.Getenv("TEST_DATABASE_URL")
	}
	if dbURL == "" {
		err = errors.New("database_url or test_database_url must be set for migrations")
		return err
	}
	// golang-migrate expects a file:// source URL with a slash-separated path.
	migrationSourceURL := "file://" + filepath.ToSlash(migrationsPath)
	// Mask the password before the URL goes anywhere near the logs.
	safeDBURL := "[redacted]"
	if parsed, parseErr := url.Parse(dbURL); parseErr == nil {
		safeDBURL = parsed.Redacted()
	}
	dm.logger.Info(context.Background(), "Migration paths", map[string]interface{}{
		"migrations_path": migrationsPath,
		"source_url":      migrationSourceURL,
		"db_url":          safeDBURL,
	})
	m, err := migrate.New(
		migrationSourceURL,
		dbURL,
	)
	if err != nil {
		err = contextutils.WrapError(err, "failed to initialize golang-migrate")
		return err
	}
	defer func() {
		if _, closeErr := m.Close(); closeErr != nil {
			dm.logger.Error(context.Background(), "Error closing migration", closeErr)
		}
	}()
	err = m.Up()
	noChange := errors.Is(err, migrate.ErrNoChange)
	if err != nil && !noChange {
		err = contextutils.WrapError(err, "golang-migrate up failed")
		return err
	}
	if noChange {
		dm.logger.Info(context.Background(), "No new golang-migrate migrations to apply.")
	} else {
		dm.logger.Info(context.Background(), "golang-migrate migrations applied successfully.")
	}
	// The explicit nil also clears ErrNoChange from the named result before
	// the deferred FinishSpan inspects it.
	return nil
}
// runApplicationSchema locates schema.sql (see getSchemaPath), splits it into
// statements (see parseSchemaStatements) and executes them through db in two
// passes: everything except index creation first, then the statements that
// start with "CREATE INDEX".
//
// For backwards compatibility with databases that already contain the schema,
// "already exists" errors are ignored in both passes, and errors about a
// column that does not exist are additionally ignored in the index pass. Any
// other error aborts the run and is returned with the failing statement in
// its message.
//
// NOTE(review): only statements beginning with "CREATE INDEX" are deferred to
// the second pass; "CREATE UNIQUE INDEX" runs in the first — confirm that is
// intended.
func (dm *Manager) runApplicationSchema(db *sql.DB) (err error) {
	// The schema file is looked up once; the path is used both for the span
	// attribute and for reading the file.
	schemaPath, err := dm.getSchemaPath()
	if err != nil {
		err = contextutils.WrapError(err, "failed to find schema file")
		return err
	}
	_, span := observability.TraceDatabaseFunction(context.Background(), "runApplicationSchema",
		attribute.String("db.system", "postgresql"),
		attribute.String("migration.type", "application_schema"),
		attribute.String("schema.path", schemaPath),
	)
	defer observability.FinishSpan(span, &err)
	schemaSQL, err := os.ReadFile(schemaPath)
	if err != nil {
		err = contextutils.WrapError(err, "failed to read schema file")
		return err
	}
	span.SetAttributes(attribute.Int("schema.file.size", len(schemaSQL)))
	statements := dm.parseSchemaStatements(string(schemaSQL))
	span.SetAttributes(attribute.Int("schema.statements.count", len(statements)))
	// First pass: run everything except index creation, which is collected
	// for the second pass.
	var indexStatements []string
	for _, statement := range statements {
		statement = strings.TrimSpace(statement)
		if statement == "" {
			continue
		}
		if strings.HasPrefix(strings.ToUpper(statement), "CREATE INDEX") {
			indexStatements = append(indexStatements, statement)
			continue
		}
		if _, execErr := db.Exec(statement); execErr != nil {
			// For backwards compatibility, ignore table exists errors.
			if !dm.isTableExistsError(execErr) {
				err = contextutils.WrapErrorf(execErr, "failed to execute schema statement: %s", statement)
				return err
			}
		}
	}
	span.SetAttributes(attribute.Int("schema.index_statements.count", len(indexStatements)))
	// Second pass: index creation.
	for _, statement := range indexStatements {
		if _, execErr := db.Exec(statement); execErr != nil {
			// For backwards compatibility, ignore index exists and missing
			// column errors.
			if !dm.isTableExistsError(execErr) && !dm.isColumnExistsError(execErr) {
				err = contextutils.WrapErrorf(execErr, "failed to execute index statement: %s", statement)
				return err
			}
		}
	}
	return nil
}
// getSchemaPath searches for a file named schema.sql, starting in the current
// working directory and moving up one parent at a time. It returns the path
// of the first match, or an error once the filesystem root has been checked
// without success.
func (dm *Manager) getSchemaPath() (result0 string, err error) {
	_, span := observability.TraceDatabaseFunction(context.Background(), "getSchemaPath",
		attribute.String("file.name", "schema.sql"),
	)
	defer observability.FinishSpan(span, &err)

	startDir, err := os.Getwd()
	if err != nil {
		return "", err
	}
	span.SetAttributes(attribute.String("search.start_dir", startDir))

	for dir := startDir; ; {
		candidate := filepath.Join(dir, "schema.sql")
		if _, statErr := os.Stat(candidate); statErr == nil {
			span.SetAttributes(attribute.String("schema.found_path", candidate))
			return candidate, nil
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// filepath.Dir is a fixed point at the root: nowhere left to look.
			span.SetAttributes(attribute.String("search.result", "not_found"))
			return "", contextutils.ErrorWithContextf("schema.sql not found in any parent directory")
		}
		dir = parent
	}
}
// parseSchemaStatements splits the text of a schema file into individual SQL
// statements.
//
// Comments are removed first: "--" to end of line, and "/* ... */" blocks,
// which may open and close on the same line, sit in the middle of a line, or
// span several lines. SQL on the same line as a comment is kept. The
// remaining non-empty lines are joined with single spaces and split on ";";
// empty statements are dropped.
//
// Limitations: the scan is purely textual, so "--", "/*" or ";" inside a
// string literal or a dollar-quoted function body is still treated as a
// comment marker or statement separator.
func (dm *Manager) parseSchemaStatements(schemaSQL string) []string {
	_, span := observability.TraceDatabaseFunction(context.Background(), "parseSchemaStatements",
		attribute.Int("input.length", len(schemaSQL)),
	)
	defer span.End()

	// inComment is true while the scan is inside a block comment that was
	// opened on an earlier line.
	inComment := false
	// stripComments returns one physical line with all comment text removed,
	// updating inComment as block comments open and close.
	stripComments := func(line string) string {
		var kept strings.Builder
		for line != "" {
			if inComment {
				end := strings.Index(line, "*/")
				if end == -1 {
					// The whole remainder belongs to the comment.
					break
				}
				inComment = false
				// A comment separates tokens just like whitespace does.
				kept.WriteByte(' ')
				line = line[end+2:]
				continue
			}
			dash := strings.Index(line, "--")
			open := strings.Index(line, "/*")
			switch {
			case dash != -1 && (open == -1 || dash < open):
				// Line comment: everything from "--" on is dropped.
				kept.WriteString(line[:dash])
				line = ""
			case open != -1:
				kept.WriteString(line[:open])
				inComment = true
				line = line[open+2:]
			default:
				kept.WriteString(line)
				line = ""
			}
		}
		return strings.TrimSpace(kept.String())
	}

	var cleanedLines []string
	for _, line := range strings.Split(schemaSQL, "\n") {
		if cleaned := stripComments(line); cleaned != "" {
			cleanedLines = append(cleanedLines, cleaned)
		}
	}

	// Join lines and split by semicolon.
	var result []string
	for _, stmt := range strings.Split(strings.Join(cleanedLines, " "), ";") {
		if stmt = strings.TrimSpace(stmt); stmt != "" {
			result = append(result, stmt)
		}
	}
	span.SetAttributes(attribute.Int("statements.parsed", len(result)))
	return result
}
// isTableExistsError reports whether err means that the object being created
// is already there: either err matches ErrTableAlreadyExists, or its message
// contains "already exists". err must be non-nil.
func (dm *Manager) isTableExistsError(err error) bool {
	_, span := observability.TraceDatabaseFunction(context.Background(), "isTableExistsError")
	defer span.End()
	// The sentinel is checked first; message matching is the fallback kept
	// for backwards compatibility.
	return errors.Is(err, ErrTableAlreadyExists) || strings.Contains(err.Error(), "already exists")
}
// isColumnExistsError reports whether err is a "column ... does not exist"
// failure, i.e. its message contains both "column" and "does not exist".
// Despite the name it detects a MISSING column; it is used to tolerate index
// statements that reference columns the database does not have. err must be
// non-nil.
func (dm *Manager) isColumnExistsError(err error) bool {
	_, span := observability.TraceDatabaseFunction(context.Background(), "isColumnExistsError")
	defer span.End()
	message := err.Error()
	if !strings.Contains(message, "column") {
		return false
	}
	return strings.Contains(message, "does not exist")
}
// GetMigrationsPath searches for an entry named "migrations", starting in the
// current working directory and moving up one parent at a time. It returns
// the path of the first match, or an error once the filesystem root has been
// checked without success. The match is not verified to be a directory.
func (dm *Manager) GetMigrationsPath() (result0 string, err error) {
	_, span := observability.TraceDatabaseFunction(context.Background(), "GetMigrationsPath",
		attribute.String("migration.dir.name", "migrations"),
	)
	defer observability.FinishSpan(span, &err)

	startDir, err := os.Getwd()
	if err != nil {
		return "", err
	}
	span.SetAttributes(attribute.String("search.start_dir", startDir))

	for dir := startDir; ; {
		candidate := filepath.Join(dir, "migrations")
		if _, statErr := os.Stat(candidate); statErr == nil {
			span.SetAttributes(attribute.String("migration.found_path", candidate))
			return candidate, nil
		}
		parent := filepath.Dir(dir)
		if parent == dir {
			// filepath.Dir is a fixed point at the root: nowhere left to look.
			span.SetAttributes(attribute.String("search.result", "not_found"))
			return "", contextutils.ErrorWithContextf("migrations directory not found in any parent directory")
		}
		dir = parent
	}
}
// Package di provides dependency injection container for managing service lifecycle and dependencies.
package di
import (
	"context"
	"database/sql"
	"reflect"
	"sync"

	"quizapp/internal/config"
	"quizapp/internal/database"
	"quizapp/internal/observability"
	"quizapp/internal/services"
	contextutils "quizapp/internal/utils"
)
// ServiceContainerInterface defines the interface for service containers
type ServiceContainerInterface interface {
GetService(name string) (interface{}, error)
GetUserService() (services.UserServiceInterface, error)
GetQuestionService() (services.QuestionServiceInterface, error)
GetLearningService() (services.LearningServiceInterface, error)
GetAIService() (services.AIServiceInterface, error)
GetWorkerService() (services.WorkerServiceInterface, error)
GetDailyQuestionService() (services.DailyQuestionServiceInterface, error)
GetStoryService() (services.StoryServiceInterface, error)
GetOAuthService() (*services.OAuthService, error)
GetGenerationHintService() (services.GenerationHintServiceInterface, error)
GetConversationService() (services.ConversationServiceInterface, error)
GetEmailService() (services.EmailServiceInterface, error)
GetTranslationService() (services.TranslationServiceInterface, error)
GetSnippetsService() (services.SnippetsServiceInterface, error)
GetUsageStatsService() (services.UsageStatsServiceInterface, error)
GetWordOfTheDayService() (services.WordOfTheDayServiceInterface, error)
GetAuthAPIKeyService() (services.AuthAPIKeyServiceInterface, error)
GetDatabase() *sql.DB
GetConfig() *config.Config
GetLogger() *observability.Logger
Initialize(ctx context.Context) error
Shutdown(ctx context.Context) error
EnsureAdminUser(ctx context.Context) error
}
// ServiceContainer manages all service dependencies and lifecycle.
type ServiceContainer struct {
	// cfg and logger are supplied to NewServiceContainer.
	cfg    *config.Config
	logger *observability.Logger

	// dbManager and db are populated by Initialize.
	dbManager *database.Manager
	db        *sql.DB

	// services maps a registration name (e.g. "user") to the service value.
	services map[string]any

	// mu is held by Initialize, GetService and Shutdown.
	mu sync.RWMutex

	// shutdownFuncs are invoked last-registered-first during cleanup.
	shutdownFuncs []func(context.Context) error
}
// NewServiceContainer creates a new dependency injection container holding
// cfg and logger, with an empty service registry. No database connection or
// services exist until Initialize is called.
func NewServiceContainer(cfg *config.Config, logger *observability.Logger) *ServiceContainer {
	container := new(ServiceContainer)
	container.cfg = cfg
	container.logger = logger
	container.services = map[string]interface{}{}
	return container
}
// Initialize sets up all services and their dependencies.
//
// While holding the container's write lock it opens the database, registers
// the connection's Close as a shutdown function, builds the services and then
// starts those that expose a Startup method. If startup fails, cleanup is run
// (its own error is discarded) and the startup error is returned wrapped.
func (sc *ServiceContainer) Initialize(ctx context.Context) error {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	// Database first: every service constructor below receives sc.db.
	manager := database.NewManager(sc.logger)
	sc.dbManager = manager
	conn, dbErr := manager.InitDBWithConfig(sc.cfg.Database)
	if dbErr != nil {
		return contextutils.WrapErrorf(dbErr, "failed to initialize database")
	}
	sc.db = conn

	closeDB := func(_ context.Context) error {
		return conn.Close()
	}
	sc.shutdownFuncs = append(sc.shutdownFuncs, closeDB)

	// Construct the services, then run their startup hooks.
	sc.initializeServices(ctx)

	startErr := sc.startupServices(ctx)
	if startErr == nil {
		return nil
	}

	// Best-effort teardown; the startup failure is the error worth reporting.
	_ = sc.cleanup(ctx)
	return contextutils.WrapErrorf(startErr, "failed to startup services")
}
// GetService returns the service registered under name as an untyped value.
// The lookup happens under the container's read lock; an error is returned
// when nothing is registered under that name.
func (sc *ServiceContainer) GetService(name string) (interface{}, error) {
	sc.mu.RLock()
	defer sc.mu.RUnlock()

	if svc, found := sc.services[name]; found {
		return svc, nil
	}
	return nil, contextutils.ErrorWithContextf("service %s not found", name)
}
// GetServiceAs performs type-safe service retrieval.
//
// It looks up the service registered under name in sc and asserts it to T.
// The zero value of T and an error are returned when the service is not
// registered or when the registered value cannot be asserted to T.
func GetServiceAs[T any](sc *ServiceContainer, name string) (T, error) {
	var zero T
	service, err := sc.GetService(name)
	if err != nil {
		return zero, err
	}
	typed, ok := service.(T)
	if !ok {
		// Formatting the zero value with %T prints "<nil>" whenever T is an
		// interface type, which hides the expected type. Resolve T through a
		// pointer instead so the message names it, and report the actual
		// dynamic type alongside it.
		want := reflect.TypeOf((*T)(nil)).Elem()
		return zero, contextutils.ErrorWithContextf("service %s is not of expected type %v (got %T)", name, want, service)
	}
	return typed, nil
}
// GetUserService returns the service registered as "user", typed as
// services.UserServiceInterface. It fails when that service is missing or
// has a different type.
func (sc *ServiceContainer) GetUserService() (services.UserServiceInterface, error) {
	const key = "user"
	return GetServiceAs[services.UserServiceInterface](sc, key)
}
// GetQuestionService returns the service registered as "question", typed as
// services.QuestionServiceInterface. It fails when that service is missing or
// has a different type.
func (sc *ServiceContainer) GetQuestionService() (services.QuestionServiceInterface, error) {
	const key = "question"
	return GetServiceAs[services.QuestionServiceInterface](sc, key)
}
// GetLearningService returns the service registered as "learning", typed as
// services.LearningServiceInterface. It fails when that service is missing or
// has a different type.
func (sc *ServiceContainer) GetLearningService() (services.LearningServiceInterface, error) {
	const key = "learning"
	return GetServiceAs[services.LearningServiceInterface](sc, key)
}
// GetAIService returns the service registered as "ai", typed as
// services.AIServiceInterface. It fails when that service is missing or has a
// different type.
func (sc *ServiceContainer) GetAIService() (services.AIServiceInterface, error) {
	const key = "ai"
	return GetServiceAs[services.AIServiceInterface](sc, key)
}
// GetWorkerService returns the service registered as "worker", typed as
// services.WorkerServiceInterface. It fails when that service is missing or
// has a different type.
func (sc *ServiceContainer) GetWorkerService() (services.WorkerServiceInterface, error) {
	const key = "worker"
	return GetServiceAs[services.WorkerServiceInterface](sc, key)
}
// GetDailyQuestionService returns the service registered as "daily_question",
// typed as services.DailyQuestionServiceInterface. It fails when that service
// is missing or has a different type.
func (sc *ServiceContainer) GetDailyQuestionService() (services.DailyQuestionServiceInterface, error) {
	const key = "daily_question"
	return GetServiceAs[services.DailyQuestionServiceInterface](sc, key)
}
// GetStoryService returns the service registered as "story", typed as
// services.StoryServiceInterface. It fails when that service is missing or
// has a different type.
func (sc *ServiceContainer) GetStoryService() (services.StoryServiceInterface, error) {
	const key = "story"
	return GetServiceAs[services.StoryServiceInterface](sc, key)
}
// GetOAuthService returns the service registered as "oauth" as a concrete
// *services.OAuthService. It fails when that service is missing or has a
// different type.
//
// The lookup goes through GetServiceAs like every other typed accessor, so
// the type-mismatch error has the same shape as theirs.
func (sc *ServiceContainer) GetOAuthService() (*services.OAuthService, error) {
	return GetServiceAs[*services.OAuthService](sc, "oauth")
}
// GetGenerationHintService returns the service registered as
// "generation_hint", typed as services.GenerationHintServiceInterface. It
// fails when that service is missing or has a different type.
func (sc *ServiceContainer) GetGenerationHintService() (services.GenerationHintServiceInterface, error) {
	const key = "generation_hint"
	return GetServiceAs[services.GenerationHintServiceInterface](sc, key)
}
// GetConversationService returns the service registered as "conversation",
// typed as services.ConversationServiceInterface. It fails when that service
// is missing or has a different type.
func (sc *ServiceContainer) GetConversationService() (services.ConversationServiceInterface, error) {
	const key = "conversation"
	return GetServiceAs[services.ConversationServiceInterface](sc, key)
}
// GetEmailService returns the service registered as "email", typed as
// services.EmailServiceInterface. It fails when that service is missing or
// has a different type.
func (sc *ServiceContainer) GetEmailService() (services.EmailServiceInterface, error) {
	const key = "email"
	return GetServiceAs[services.EmailServiceInterface](sc, key)
}
// GetTranslationService returns the service registered as "translation",
// typed as services.TranslationServiceInterface. It fails when that service
// is missing or has a different type.
func (sc *ServiceContainer) GetTranslationService() (services.TranslationServiceInterface, error) {
	const key = "translation"
	return GetServiceAs[services.TranslationServiceInterface](sc, key)
}
// GetSnippetsService returns the service registered as "snippets", typed as
// services.SnippetsServiceInterface. It fails when that service is missing or
// has a different type.
func (sc *ServiceContainer) GetSnippetsService() (services.SnippetsServiceInterface, error) {
	const key = "snippets"
	return GetServiceAs[services.SnippetsServiceInterface](sc, key)
}
// GetUsageStatsService returns the service registered as "usage_stats", typed
// as services.UsageStatsServiceInterface. It fails when that service is
// missing or has a different type.
func (sc *ServiceContainer) GetUsageStatsService() (services.UsageStatsServiceInterface, error) {
	const key = "usage_stats"
	return GetServiceAs[services.UsageStatsServiceInterface](sc, key)
}
// GetWordOfTheDayService returns the service registered as "word_of_the_day",
// typed as services.WordOfTheDayServiceInterface. It fails when that service
// is missing or has a different type.
func (sc *ServiceContainer) GetWordOfTheDayService() (services.WordOfTheDayServiceInterface, error) {
	const key = "word_of_the_day"
	return GetServiceAs[services.WordOfTheDayServiceInterface](sc, key)
}
// GetAuthAPIKeyService returns the service registered as "auth_api_key",
// typed as services.AuthAPIKeyServiceInterface. It fails when that service is
// missing or has a different type.
func (sc *ServiceContainer) GetAuthAPIKeyService() (services.AuthAPIKeyServiceInterface, error) {
	const key = "auth_api_key"
	return GetServiceAs[services.AuthAPIKeyServiceInterface](sc, key)
}
// GetDatabase returns the container's *sql.DB. The field is assigned by
// Initialize, so the result is nil before a successful Initialize call.
func (sc *ServiceContainer) GetDatabase() *sql.DB {
	conn := sc.db
	return conn
}
// GetConfig returns the configuration the container was constructed with.
func (sc *ServiceContainer) GetConfig() *config.Config {
	current := sc.cfg
	return current
}
// GetLogger returns the logger the container was constructed with.
func (sc *ServiceContainer) GetLogger() *observability.Logger {
	current := sc.logger
	return current
}
// Shutdown gracefully shuts down all services by running cleanup while
// holding the container's write lock, and returns cleanup's result.
func (sc *ServiceContainer) Shutdown(ctx context.Context) error {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	result := sc.cleanup(ctx)
	return result
}
// startupServices calls Startup(ctx) on every registered service that has
// such a method. It stops at the first failure and returns that error wrapped
// with the service name. Services are visited in map iteration order.
func (sc *ServiceContainer) startupServices(ctx context.Context) error {
	// starter is the optional startup hook a service may implement.
	type starter interface {
		Startup(context.Context) error
	}

	for svcName, svc := range sc.services {
		hook, hasStartup := svc.(starter)
		if !hasStartup {
			continue
		}

		sc.logger.Info(ctx, "Starting service", map[string]interface{}{"service": svcName})
		startErr := hook.Startup(ctx)
		if startErr != nil {
			return contextutils.WrapErrorf(startErr, "failed to startup service %s", svcName)
		}
		sc.logger.Info(ctx, "Service started successfully", map[string]interface{}{"service": svcName})
	}
	return nil
}
// cleanup handles shutdown of all services.
//
// It first calls Shutdown(ctx) on every registered service that has such a
// method (visited in map iteration order, which is unspecified), then runs
// the registered shutdown functions from last-registered to first. Failures
// do not stop the process; they are collected and reported together in a
// single error. It returns nil when nothing failed.
func (sc *ServiceContainer) cleanup(ctx context.Context) error {
	// stopper is the optional shutdown hook a service may implement.
	type stopper interface {
		Shutdown(context.Context) error
	}

	var failures []error

	for svcName, svc := range sc.services {
		hook, hasShutdown := svc.(stopper)
		if !hasShutdown {
			continue
		}

		sc.logger.Info(ctx, "Shutting down service", map[string]interface{}{"service": svcName})
		stopErr := hook.Shutdown(ctx)
		if stopErr == nil {
			sc.logger.Info(ctx, "Service shutdown successfully", map[string]interface{}{"service": svcName})
			continue
		}
		sc.logger.Error(ctx, "Failed to shutdown service", stopErr, map[string]interface{}{"service": svcName})
		failures = append(failures, contextutils.WrapErrorf(stopErr, "service %s shutdown failed", svcName))
	}

	// Run shutdown functions in reverse order of registration.
	for idx := len(sc.shutdownFuncs) - 1; idx >= 0; idx-- {
		fn := sc.shutdownFuncs[idx]
		if fnErr := fn(ctx); fnErr != nil {
			failures = append(failures, fnErr)
		}
	}

	if len(failures) == 0 {
		return nil
	}
	return contextutils.ErrorWithContextf("shutdown errors: %v", failures)
}
// initializeServices constructs every application service and stores it in
// the registry under its lookup name. Constructors run in dependency order:
// a service that is handed to a later constructor is created first. Finally
// a no-op shutdown function is appended as a placeholder.
func (sc *ServiceContainer) initializeServices(_ context.Context) {
	db, cfg, logger := sc.db, sc.cfg, sc.logger
	registry := sc.services

	// User service: built from db, config and logger only.
	registry["user"] = services.NewUserServiceWithLogger(db, cfg, logger)

	// Learning service: passed to the question and daily question services.
	learning := services.NewLearningServiceWithLogger(db, cfg, logger)
	registry["learning"] = learning

	// Question service: receives the learning service.
	question := services.NewQuestionServiceWithLogger(db, learning, cfg, logger)
	registry["question"] = question

	// Daily question service: receives the question and learning services.
	registry["daily_question"] = services.NewDailyQuestionService(db, logger, question, learning)

	// Services built from db/config/logger only.
	registry["story"] = services.NewStoryService(db, cfg, logger)
	registry["worker"] = services.NewWorkerServiceWithLogger(db, logger)
	registry["generation_hint"] = services.NewGenerationHintService(db, logger)
	registry["oauth"] = services.NewOAuthServiceWithLogger(cfg, logger)
	registry["conversation"] = services.NewConversationService(db)

	// Email service (use concrete implementation with DB to satisfy EmailServiceInterface).
	registry["email"] = services.NewEmailServiceWithDB(cfg, logger, db)

	// Usage stats service: passed to the AI and translation services.
	usageStats := services.NewUsageStatsService(cfg, db, logger)
	registry["usage_stats"] = usageStats

	// AI service: receives the usage stats service.
	registry["ai"] = services.NewAIService(cfg, logger, usageStats)

	// Translation cache repository: passed to the translation service.
	translationCache := services.NewTranslationCacheRepository(db, logger)
	registry["translation_cache"] = translationCache

	// Translation service: receives usage stats and the cache repository.
	registry["translation"] = services.NewTranslationService(cfg, usageStats, translationCache, logger)

	// Remaining services built from db/config/logger only.
	registry["snippets"] = services.NewSnippetsService(db, cfg, logger)
	registry["word_of_the_day"] = services.NewWordOfTheDayService(db, logger)
	registry["auth_api_key"] = services.NewAuthAPIKeyService(db, logger)

	// Placeholder for future service shutdowns.
	noop := func(_ context.Context) error { return nil }
	sc.shutdownFuncs = append(sc.shutdownFuncs, noop)
}
// EnsureAdminUser asks the user service to make sure the admin account exists,
// using the admin username and password from the server configuration. It
// returns a wrapped error when the user service cannot be obtained, otherwise
// whatever EnsureAdminUserExists returns.
func (sc *ServiceContainer) EnsureAdminUser(ctx context.Context) error {
	users, lookupErr := sc.GetUserService()
	if lookupErr != nil {
		return contextutils.WrapErrorf(lookupErr, "failed to get user service")
	}

	adminName := sc.cfg.Server.AdminUsername
	adminPassword := sc.cfg.Server.AdminPassword
	return users.EnsureAdminUserExists(ctx, adminName, adminPassword)
}
// Package handlers provides HTTP request handlers for the quiz application API.
package handlers
import (
"context"
"database/sql"
"encoding/json"
"errors"
"html/template"
"math"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// AdminHandler handles administrative HTTP requests and dashboard functionality.
// Its dependencies are injected through NewAdminHandlerWithLogger.
type AdminHandler struct {
// userService supplies the user list (and individual users) to the handlers.
userService services.UserServiceInterface
// questionService provides question statistics and question updates.
questionService services.QuestionServiceInterface
// aiService exposes AI concurrency statistics and AI calls.
aiService services.AIServiceInterface
// config is read for the worker port/base URL shown in admin responses.
config *config.Config
// templates, when non-nil, is used to render "backend_admin.html";
// when nil the admin page handler responds with JSON instead.
templates *template.Template
// learningService provides per-user progress and question stats.
learningService services.LearningServiceInterface
// workerService is optional; worker health is skipped when it is nil.
workerService services.WorkerServiceInterface
// logger receives warnings and errors from the handlers.
logger *observability.Logger
// storyService is not assigned by NewAdminHandlerWithLogger.
// NOTE(review): presumably set elsewhere — confirm before relying on it.
storyService services.StoryServiceInterface
// usageStatsSvc is stored by the constructor.
// NOTE(review): no use is visible in this part of the file.
usageStatsSvc services.UsageStatsServiceInterface
}
// NewAdminHandlerWithLogger creates a new AdminHandler with the provided
// services and logger. The templates and storyService fields are left at
// their zero values.
func NewAdminHandlerWithLogger(userService services.UserServiceInterface, questionService services.QuestionServiceInterface, aiService services.AIServiceInterface, cfg *config.Config, learningService services.LearningServiceInterface, workerService services.WorkerServiceInterface, logger *observability.Logger, usageStatsSvc services.UsageStatsServiceInterface) *AdminHandler {
	handler := &AdminHandler{}

	// Configuration and logging.
	handler.config = cfg
	handler.logger = logger

	// Service dependencies.
	handler.userService = userService
	handler.questionService = questionService
	handler.aiService = aiService
	handler.learningService = learningService
	handler.workerService = workerService
	handler.usageStatsSvc = usageStatsSvc

	return handler
}
// GetBackendAdminData returns the backend administration data as JSON.
//
// The response combines aggregate user statistics, detailed question
// statistics, worker health (only when a worker service is configured), AI
// concurrency counters and the worker port/base URL from configuration.
// Failing to load the user list aborts the request with an error; failures
// in question statistics or worker health are logged and replaced with
// fallback values.
func (h *AdminHandler) GetBackendAdminData(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_backend_admin_data")
	defer observability.FinishSpan(span, nil)

	// The user list is mandatory: everything else degrades gracefully.
	allUsers, usersErr := h.userService.GetAllUsers(ctx)
	if usersErr != nil {
		span.SetAttributes(attribute.String("error", usersErr.Error()))
		HandleAppError(c, contextutils.WrapError(usersErr, "failed to get users"))
		return
	}
	aggregate := calculateUserAggregateStats(ctx, allUsers, h.learningService, h.logger)

	// Question statistics fall back to an empty map on failure.
	questionSummary, statsErr := h.questionService.GetDetailedQuestionStats(ctx)
	if statsErr != nil {
		h.logger.Warn(ctx, "Failed to get question stats", map[string]interface{}{"error": statsErr.Error()})
		questionSummary = make(map[string]interface{})
	}

	// Worker health stays nil when no worker service is configured.
	var health map[string]interface{}
	if h.workerService != nil {
		var healthErr error
		health, healthErr = h.workerService.GetWorkerHealth(ctx)
		if healthErr != nil {
			h.logger.Warn(ctx, "Failed to get worker health", map[string]interface{}{"error": healthErr.Error()})
			health = map[string]interface{}{
				"error": "Failed to get worker health",
			}
		}
	}

	// Flatten the AI concurrency stats struct into a JSON-friendly map.
	concurrency := h.aiService.GetConcurrencyStats()
	aiSummary := map[string]interface{}{
		"active_requests":   concurrency.ActiveRequests,
		"max_concurrent":    concurrency.MaxConcurrent,
		"queued_requests":   concurrency.QueuedRequests,
		"total_requests":    concurrency.TotalRequests,
		"user_active_count": concurrency.UserActiveCount,
		"max_per_user":      concurrency.MaxPerUser,
	}

	c.JSON(http.StatusOK, gin.H{
		"user_stats":           aggregate,
		"question_stats":       questionSummary,
		"worker_health":        health,
		"ai_concurrency_stats": aiSummary,
		"worker_port":          h.config.Server.WorkerPort,
		"worker_base_url":      h.config.Server.WorkerBaseURL,
	})
}
// GetBackendAdminPage renders the backend administration dashboard.
//
// For every user it gathers learning progress and question statistics
// (substituting defaults and logging a warning when either lookup fails),
// then adds global question statistics, worker health and AI concurrency
// counters. When h.templates is set the data is rendered through the
// "backend_admin.html" template with no-cache headers; otherwise the same
// data is returned as JSON.
func (h *AdminHandler) GetBackendAdminPage(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_backend_admin_page")
	defer observability.FinishSpan(span, nil)

	// The user list is mandatory: everything else degrades gracefully.
	allUsers, usersErr := h.userService.GetAllUsers(ctx)
	if usersErr != nil {
		span.SetAttributes(attribute.String("error", usersErr.Error()))
		HandleAppError(c, contextutils.WrapError(usersErr, "failed to get users"))
		return
	}

	type UserWithProgress struct {
		User               models.User
		Progress           *models.UserProgress
		QuestionStats      *services.UserQuestionStats
		UserQuestionCounts map[string]interface{}
	}

	var rows []UserWithProgress
	for _, account := range allUsers {
		// Progress: default to level A1 with zeroed counters on failure.
		userProgress, progressErr := h.learningService.GetUserProgress(ctx, account.ID)
		if progressErr != nil {
			h.logger.Warn(ctx, "Failed to get progress for user", map[string]interface{}{"user_id": account.ID, "error": progressErr.Error()})
			userProgress = &models.UserProgress{CurrentLevel: "A1"}
		}

		// Per-user question stats: default to an empty record on failure.
		perUserStats, perUserErr := h.learningService.GetUserQuestionStats(ctx, account.ID)
		if perUserErr != nil {
			h.logger.Warn(ctx, "Failed to get question stats for user", map[string]interface{}{"user_id": account.ID, "error": perUserErr.Error()})
			perUserStats = &services.UserQuestionStats{UserID: account.ID}
		}

		// Expose the stats breakdowns as a flat map keyed by name.
		counts := map[string]interface{}{}
		if perUserStats != nil {
			counts["total_answered"] = perUserStats.TotalAnswered
			counts["answered_by_type"] = perUserStats.AnsweredByType
			counts["answered_by_level"] = perUserStats.AnsweredByLevel
			counts["accuracy_by_type"] = perUserStats.AccuracyByType
			counts["accuracy_by_level"] = perUserStats.AccuracyByLevel
			counts["available_by_type"] = perUserStats.AvailableByType
			counts["available_by_level"] = perUserStats.AvailableByLevel
		}

		rows = append(rows, UserWithProgress{
			User:               account,
			Progress:           userProgress,
			QuestionStats:      perUserStats,
			UserQuestionCounts: counts,
		})
	}

	// Global question statistics fall back to an empty map on failure.
	questionSummary, statsErr := h.questionService.GetDetailedQuestionStats(ctx)
	if statsErr != nil {
		h.logger.Warn(ctx, "Failed to get question stats", map[string]interface{}{"error": statsErr.Error()})
		questionSummary = make(map[string]interface{})
	}

	// Worker health stays nil when no worker service is configured.
	var health map[string]interface{}
	if h.workerService != nil {
		var healthErr error
		health, healthErr = h.workerService.GetWorkerHealth(ctx)
		if healthErr != nil {
			h.logger.Warn(ctx, "Failed to get worker health", map[string]interface{}{"error": healthErr.Error()})
			health = map[string]interface{}{
				"error": "Failed to get worker health",
			}
		}
	}

	// Flatten the AI concurrency stats struct into a map.
	concurrency := h.aiService.GetConcurrencyStats()
	aiSummary := map[string]interface{}{
		"active_requests":   concurrency.ActiveRequests,
		"max_concurrent":    concurrency.MaxConcurrent,
		"queued_requests":   concurrency.QueuedRequests,
		"total_requests":    concurrency.TotalRequests,
		"user_active_count": concurrency.UserActiveCount,
		"max_per_user":      concurrency.MaxPerUser,
	}

	pageData := gin.H{
		"Title":              "Backend Administration",
		"Users":              rows,
		"QuestionStats":      questionSummary,
		"WorkerHealth":       health,
		"AIConcurrencyStats": aiSummary,
		"IsBackend":          true,
		"WorkerPort":         h.config.Server.WorkerPort,
		"CurrentPage":        "backend_admin",
		"WorkerBaseURL":      h.config.Server.WorkerBaseURL,
	}

	// Without templates the page data is served as JSON.
	if h.templates == nil {
		c.JSON(http.StatusOK, pageData)
		return
	}

	// HTML response: make sure the dashboard is never cached.
	c.Header("Content-Type", "text/html; charset=utf-8")
	c.Header("Cache-Control", "no-cache, no-store, must-revalidate")
	c.Header("Pragma", "no-cache")
	c.Header("Expires", "0")

	renderErr := h.templates.ExecuteTemplate(c.Writer, "backend_admin.html", pageData)
	if renderErr != nil {
		h.logger.Error(ctx, "Template execution failed", renderErr, map[string]interface{}{})
		HandleAppError(c, contextutils.WrapError(renderErr, "failed to render template"))
		return
	}
}
// UserData represents user information combined with their progress data.
type UserData struct {
// User is the user record itself (stored by value).
User models.User
// Progress points to the user's progress record.
// NOTE(review): whether it may be nil is not visible here — confirm at the call sites.
Progress *models.UserProgress
}
// UserDataWithQuestions represents user information with questions and responses.
// NOTE(review): nothing in this part of the file populates this type; the
// field notes below are inferred from names and types — confirm with callers.
type UserDataWithQuestions struct {
// User is the user record itself (stored by value).
User models.User
// Progress points to the user's progress record.
Progress *models.UserProgress
// QuestionStats points to the user's question statistics.
QuestionStats *services.UserQuestionStats
// TotalQuestions and TotalResponses are presumably per-user counts — TODO confirm.
TotalQuestions int
TotalResponses int
// RecentQuestions holds strings describing recent questions — exact content unverified.
RecentQuestions []string
Questions []*services.QuestionWithStats // Actual question objects with stats
}
// ReportedQuestionsData represents the structure for reported questions page data.
type ReportedQuestionsData struct {
// Users lists the users together with their question data.
Users []UserDataWithQuestions
// ReportedQuestions lists the reported questions, each carrying user information.
ReportedQuestions []*services.ReportedQuestionWithUser
}
// ShowDatazPage - Removed: Use frontend admin interface instead
// MarkQuestionAsFixed marks a reported question as fixed and puts it back in rotation.
//
// The question ID is read from the "id" path parameter. A non-numeric ID is
// rejected with ErrInvalidFormat; a record-not-found error from the question
// service is reported as ErrQuestionNotFound; any other failure is wrapped
// and reported. On success a JSON confirmation message is returned.
func (h *AdminHandler) MarkQuestionAsFixed(c *gin.Context) {
	id, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	fixErr := h.questionService.MarkQuestionAsFixed(c.Request.Context(), id)
	if fixErr == nil {
		c.JSON(http.StatusOK, gin.H{"message": "Question marked as fixed successfully"})
		return
	}

	h.logger.Error(c.Request.Context(), "Failed to mark question as fixed", fixErr, map[string]interface{}{"question_id": id})

	// Translate a missing record into the question-specific not-found error.
	if contextutils.IsError(fixErr, contextutils.ErrRecordNotFound) {
		HandleAppError(c, contextutils.ErrQuestionNotFound)
		return
	}
	HandleAppError(c, contextutils.WrapError(fixErr, "failed to mark question as fixed"))
}
// UpdateQuestion updates a question's content, correct answer, and explanation.
//
// The question ID comes from the "id" path parameter and the new values from
// the JSON body. Before saving, the content payload is normalized: nested
// "content" objects are unwrapped, the keys "correct_answer", "explanation"
// and "change_reason" are removed from it, and a missing or null "options"
// entry becomes an empty list. When the query parameter mark_fixed equals
// "true" (case-insensitively) the question is also marked as fixed and its
// reports are deleted; a failure to delete reports is only logged.
func (h *AdminHandler) UpdateQuestion(c *gin.Context) {
	questionID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var req struct {
		Content       map[string]interface{} `json:"content" binding:"required"`
		CorrectAnswer int                    `json:"correct_answer" binding:"gte=0,lte=3"`
		Explanation   string                 `json:"explanation" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request format",
			"",
			bindErr,
		))
		return
	}

	// Unwrap `content.content...` until the innermost object is reached, so
	// the stored payload is never nested inside itself.
	payload := req.Content
	for {
		nested, isObject := payload["content"].(map[string]interface{})
		if !isObject {
			break
		}
		payload = nested
	}

	// These values travel as dedicated request fields; drop stray copies from
	// the content payload (defensive cleanup while migrating to strict
	// OpenAPI validation).
	for _, duplicate := range []string{"correct_answer", "explanation", "change_reason"} {
		delete(payload, duplicate)
	}

	// Ensure options is not nil (convert missing/null -> empty slice).
	if payload["options"] == nil {
		payload["options"] = []string{}
	}

	updateErr := h.questionService.UpdateQuestion(c.Request.Context(), questionID, payload, req.CorrectAnswer, req.Explanation)
	if updateErr != nil {
		h.logger.Error(c.Request.Context(), "Failed to update question", updateErr, map[string]interface{}{"question_id": questionID})
		// Translate a missing record into the question-specific not-found error.
		if contextutils.IsError(updateErr, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(updateErr, "failed to update question"))
		return
	}

	// If requested, mark the question as fixed and clear reports.
	markFixed := strings.ToLower(c.Query("mark_fixed")) == "true"
	if markFixed {
		ctx := c.Request.Context()

		// Mark as fixed (sets status to active).
		fixErr := h.questionService.MarkQuestionAsFixed(ctx, questionID)
		if fixErr != nil {
			h.logger.Error(ctx, "Failed to mark question as fixed after update", fixErr, map[string]interface{}{"question_id": questionID})
			HandleAppError(c, contextutils.WrapError(fixErr, "failed to mark question as fixed"))
			return
		}

		// Clearing reports is best-effort: a failure is logged, not returned.
		_, clearErr := h.questionService.DB().ExecContext(ctx, `DELETE FROM question_reports WHERE question_id = $1`, questionID)
		if clearErr != nil {
			h.logger.Warn(ctx, "Failed to clear question reports", map[string]interface{}{"question_id": questionID, "error": clearErr.Error()})
		}
	}

	c.JSON(http.StatusOK, gin.H{"success": true, "message": "Question updated successfully"})
}
// FixQuestionWithAI uses AI to suggest fixes for a problematic question
func (h *AdminHandler) FixQuestionWithAI(c *gin.Context) {
questionIDStr := c.Param("id")
questionID, err := strconv.Atoi(questionIDStr)
if err != nil {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
// Get the original question
question, err := h.questionService.GetQuestionByID(c.Request.Context(), questionID)
if err != nil {
h.logger.Error(c.Request.Context(), "Failed to get question", err, map[string]interface{}{"question_id": questionID})
// Check if the error is due to question not found
if errors.Is(err, sql.ErrNoRows) {
HandleAppError(c, contextutils.ErrQuestionNotFound)
return
}
HandleAppError(c, contextutils.WrapError(err, "failed to get question"))
return
}
// Find reporter(s) and choose a configured AI provider/model from the reporting user(s)
ctx := c.Request.Context()
db := h.questionService.DB()
rows, err := db.QueryContext(ctx, `SELECT u.id, u.username, u.ai_provider, u.ai_model, qr.report_reason FROM question_reports qr JOIN users u ON qr.reported_by_user_id = u.id WHERE qr.question_id = $1 ORDER BY qr.created_at ASC`, questionID)
if err != nil {
h.logger.Error(ctx, "Failed to query question reports", err, map[string]interface{}{"question_id": questionID})
HandleAppError(c, contextutils.WrapError(err, "failed to get report details"))
return
}
if err := rows.Err(); err != nil {
h.logger.Warn(ctx, "rows iteration error before defer", map[string]interface{}{"error": err.Error(), "question_id": questionID})
}
defer func() {
if err := rows.Close(); err != nil {
h.logger.Warn(ctx, "Failed to close report rows", map[string]interface{}{"error": err.Error(), "question_id": questionID})
}
}()
var reporterID int
var reporterUsername string
var reporterProvider sql.NullString
var reporterModel sql.NullString
var singleReason sql.NullString
foundProvider := false
for rows.Next() {
var uid int
var uname string
var prov sql.NullString
var mod sql.NullString
var reason sql.NullString
if err := rows.Scan(&uid, &uname, &prov, &mod, &reason); err != nil {
h.logger.Warn(ctx, "Failed to scan report row", map[string]interface{}{"error": err.Error(), "question_id": questionID})
continue
}
// Prefer the first reporter that has an AI provider+model configured
if prov.Valid && prov.String != "" && mod.Valid && mod.String != "" {
reporterID = uid
reporterUsername = uname
reporterProvider = prov
reporterModel = mod
singleReason = reason
foundProvider = true
break
}
// Keep the first reporter as fallback (no provider)
if reporterID == 0 {
reporterID = uid
reporterUsername = uname
reporterProvider = prov
reporterModel = mod
singleReason = reason
}
}
if !foundProvider {
// If no reporting user has AI configured, fall back to admin user's AI settings or global default provider
h.logger.Info(ctx, "No reporting user has AI configured; attempting fallback to admin or global provider", map[string]interface{}{"question_id": questionID})
// Try to get current admin user from context/session
var adminUserID int
if uid, err := GetCurrentUserID(c); err == nil {
adminUserID = uid
}
// Try admin user's configured provider/model
if adminUserID != 0 {
adminUser, err := h.userService.GetUserByID(ctx, adminUserID)
if err == nil && adminUser != nil && adminUser.AIProvider.Valid && adminUser.AIProvider.String != "" && adminUser.AIModel.Valid && adminUser.AIModel.String != "" {
reporterID = adminUser.ID
reporterUsername = adminUser.Username
reporterProvider = adminUser.AIProvider
reporterModel = adminUser.AIModel
foundProvider = true
h.logger.Info(ctx, "Falling back to admin user's AI provider", map[string]interface{}{"admin_id": adminUserID, "provider": adminUser.AIProvider.String, "model": adminUser.AIModel.String})
}
}
// If still not found, try global config first provider
if !foundProvider && h.config != nil && len(h.config.Providers) > 0 {
p := h.config.Providers[0]
if len(p.Models) > 0 {
// Use first provider and model from global config
reporterProvider = sql.NullString{String: p.Code, Valid: true}
reporterModel = sql.NullString{String: p.Models[0].Code, Valid: true}
reporterUsername = "system"
foundProvider = true
h.logger.Info(ctx, "Falling back to global configured AI provider", map[string]interface{}{"provider": p.Code, "model": p.Models[0].Code})
}
}
if !foundProvider {
h.logger.Warn(ctx, "No AI provider configured for reporting users and no fallback available", map[string]interface{}{"question_id": questionID})
HandleAppError(c, contextutils.ErrAIConfigInvalid)
return
}
}
// Get saved API key for the reporter's configured provider
savedKey, apiKeyID, _ := h.userService.GetUserAPIKeyWithID(ctx, reporterID, reporterProvider.String)
userCfg := &models.UserAIConfig{
Provider: reporterProvider.String,
Model: reporterModel.String,
APIKey: savedKey,
Username: reporterUsername,
}
// Build AI chat request with question details and report reasons
// Use the template manager to render a structured prompt
// Prepare template data
questionContentJSON, _ := question.MarshalContentToJSON()
// Resolve schema for prompt; fail if none
schema, err := services.GetFixSchema(question.Type)
if err != nil {
h.logger.Error(ctx, "No schema available for question type", err, map[string]interface{}{"question_id": questionID, "type": question.Type})
HandleAppError(c, contextutils.ErrAIConfigInvalid)
return
}
// Read optional additional_context from POST body JSON
var body struct {
AdditionalContext string `json:"additional_context"`
}
_ = c.BindJSON(&body) // ignore error; body may be empty
tmplData := services.AITemplateData{
CurrentQuestionJSON: questionContentJSON,
ExampleContent: "", // will be filled below if example available
SchemaForPrompt: schema,
ReportReasons: []string{},
AdditionalContext: body.AdditionalContext,
}
if singleReason.Valid {
tmplData.ReportReasons = []string{singleReason.String}
}
// Load example for this question type if available
if ex, err := h.aiService.TemplateManager().LoadExample(string(question.Type)); err == nil {
tmplData.ExampleContent = ex
}
prompt, err := h.aiService.TemplateManager().RenderTemplate(services.AIFixPromptTemplate, tmplData)
if err != nil {
h.logger.Error(ctx, "Failed to render AI fix prompt", err, map[string]interface{}{"question_id": questionID})
HandleAppError(c, contextutils.WrapError(err, "failed to build AI prompt"))
return
}
// Use schema as grammar for providers that support it
supportsGrammar := h.aiService.SupportsGrammarField(userCfg.Provider)
var grammar string
if supportsGrammar {
grammar, err = services.GetFixSchema(question.Type)
if err != nil {
h.logger.Error(ctx, "No grammar schema available for question type", err, map[string]interface{}{"question_id": questionID, "type": question.Type})
HandleAppError(c, contextutils.ErrAIConfigInvalid)
return
}
} else {
grammar = ""
}
// Add user ID and API key ID to context for usage tracking
if reporterID != 0 {
ctx = contextutils.WithUserID(ctx, reporterID)
}
if apiKeyID != nil {
ctx = contextutils.WithAPIKeyID(ctx, *apiKeyID)
}
// Call AI service with constructed prompt and grammar
respStr, err := h.aiService.CallWithPrompt(ctx, userCfg, prompt, grammar)
if err != nil {
h.logger.Error(ctx, "AI service call failed", err, map[string]interface{}{"question_id": questionID, "provider": userCfg.Provider})
HandleAppError(c, contextutils.WrapError(err, "AI service error"))
return
}
// Attempt to parse AI response as JSON (and try to recover JSON substring if necessary)
var aiResp map[string]interface{}
if err := json.Unmarshal([]byte(respStr), &aiResp); err != nil {
start := strings.Index(respStr, "{")
end := strings.LastIndex(respStr, "}")
if start >= 0 && end > start {
candidate := respStr[start : end+1]
if err2 := json.Unmarshal([]byte(candidate), &aiResp); err2 != nil {
h.logger.Error(ctx, "Failed to parse AI response as JSON", err2, map[string]interface{}{"question_id": questionID})
HandleAppError(c, contextutils.ErrAIResponseInvalid)
return
}
} else {
h.logger.Error(ctx, "AI did not return JSON", nil, map[string]interface{}{"question_id": questionID})
HandleAppError(c, contextutils.ErrAIResponseInvalid)
return
}
}
// Start from the original question map so required top-level fields are preserved
originalMap := map[string]interface{}{}
if b, err := json.Marshal(question); err == nil {
_ = json.Unmarshal(b, &originalMap)
}
// Use helper to merge and normalize AI suggestion into original map
suggestion := MergeAISuggestion(originalMap, aiResp)
// Attach admin-provided additional context into suggestion metadata so frontend can display it
if body.AdditionalContext != "" {
suggestion["additional_context"] = body.AdditionalContext
}
// If query param apply=true present, apply suggestion directly and mark fixed
if strings.ToLower(c.Query("apply")) == "true" {
// Build update payload: use merged content and read answer/explanation from TOP LEVEL
updateContent := suggestion["content"].(map[string]interface{})
// Extract correct_answer from top level (support float64 from JSON)
correctAnswer := 0
if ca, ok := suggestion["correct_answer"]; ok {
switch v := ca.(type) {
case float64:
correctAnswer = int(v)
case int:
correctAnswer = v
}
}
// Extract explanation from top level
explanation := ""
if ex, ok := suggestion["explanation"].(string); ok {
explanation = ex
}
if err := h.questionService.UpdateQuestion(c.Request.Context(), questionID, updateContent, correctAnswer, explanation); err != nil {
h.logger.Error(c.Request.Context(), "Failed to update question with AI suggestion", err, map[string]interface{}{"question_id": questionID})
HandleAppError(c, contextutils.WrapError(err, "failed to apply suggestion"))
return
}
if err := h.questionService.MarkQuestionAsFixed(c.Request.Context(), questionID); err != nil {
h.logger.Warn(c.Request.Context(), "Failed to mark question as fixed after applying suggestion", map[string]interface{}{"question_id": questionID, "error": err.Error()})
}
db := h.questionService.DB()
if _, err := db.ExecContext(c.Request.Context(), `DELETE FROM question_reports WHERE question_id = $1`, questionID); err != nil {
h.logger.Warn(c.Request.Context(), "Failed to clear question reports after applying suggestion", map[string]interface{}{"question_id": questionID, "error": err.Error()})
}
c.JSON(http.StatusOK, gin.H{"success": true, "message": "Suggestion applied"})
return
}
// Return original question and merged AI suggestion for frontend review
c.JSON(http.StatusOK, gin.H{
"original": question,
"suggestion": suggestion,
})
}
// ServeDatazJS - Removed: Use frontend admin interface instead
// GetAIConcurrencyStats responds with the concurrency statistics reported by
// the handler's AI service, wrapped under the "ai_concurrency" key.
func (h *AdminHandler) GetAIConcurrencyStats(c *gin.Context) {
	payload := gin.H{"ai_concurrency": h.aiService.GetConcurrencyStats()}
	c.JSON(http.StatusOK, payload)
}
// --- Story Explorer (Admin) ---
// GetStoriesPaginated serves one page of stories for the admin story explorer.
//
// Pagination comes from ParsePagination and the "search", "language" and
// "status" filters from ParseFilters. The optional "user_id" query parameter
// must parse as a strictly positive integer, otherwise ErrInvalidFormat is
// reported. The response holds the stories as generic maps under "stories"
// and a "pagination" object with page, page_size, total and total_pages.
// A nil story service is reported as ErrInternalError.
func (h *AdminHandler) GetStoriesPaginated(c *gin.Context) {
	if h.storyService == nil {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	page, pageSize := ParsePagination(c, 1, 20, 100)
	filters := ParseFilters(c, "search", "language", "status")

	// user_id is optional; when supplied it narrows the listing to one user.
	var userID *uint
	if raw := c.Query("user_id"); raw != "" {
		n, convErr := strconv.Atoi(raw)
		if convErr != nil || n <= 0 {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
		uid := uint(n)
		userID = &uid
	}

	reqCtx := c.Request.Context()
	stories, total, err := h.storyService.GetStoriesPaginated(reqCtx, page, pageSize, filters["search"], filters["language"], filters["status"], userID)
	if err != nil {
		h.logger.Error(reqCtx, "Failed to get stories", err, map[string]interface{}{"page": page, "size": pageSize})
		HandleAppError(c, contextutils.WrapError(err, "failed to get stories"))
		return
	}

	// Convert each story to its API struct, then round-trip it through JSON to
	// obtain a plain map. A story that fails to marshal yields an empty map.
	storyMaps := make([]map[string]interface{}, 0, len(stories))
	for _, story := range stories {
		entry := map[string]interface{}{}
		if encoded, marshalErr := json.Marshal(convertStoryToAPI(&story)); marshalErr == nil {
			_ = json.Unmarshal(encoded, &entry)
		}
		storyMaps = append(storyMaps, entry)
	}

	totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
	c.JSON(http.StatusOK, gin.H{
		"stories": storyMaps,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
	})
}
// GetStoryAdmin looks up a single story by the positive integer "id" path
// parameter and responds with its API representation (story plus sections).
//
// A nil story service yields ErrInternalError, a malformed or non-positive id
// yields ErrInvalidFormat, and a service error whose message contains
// "story not found" yields ErrRecordNotFound; any other error is wrapped.
func (h *AdminHandler) GetStoryAdmin(c *gin.Context) {
	if h.storyService == nil {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	storyID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil || storyID <= 0 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	reqCtx := c.Request.Context()
	story, err := h.storyService.GetStoryAdmin(reqCtx, uint(storyID))
	if err != nil {
		h.logger.Error(reqCtx, "Failed to get story", err, map[string]interface{}{"story_id": storyID})
		switch {
		case strings.Contains(err.Error(), "story not found"):
			HandleAppError(c, contextutils.ErrRecordNotFound)
		default:
			HandleAppError(c, contextutils.WrapError(err, "failed to get story"))
		}
		return
	}

	c.JSON(http.StatusOK, convertStoryWithSectionsToAPI(story))
}
// GetSectionAdmin looks up a single story section by the positive integer
// "id" path parameter and responds with its API representation (section plus
// questions).
//
// A nil story service yields ErrInternalError, a malformed or non-positive id
// yields ErrInvalidFormat, and a service error whose message contains
// "section not found" yields ErrRecordNotFound; any other error is wrapped.
func (h *AdminHandler) GetSectionAdmin(c *gin.Context) {
	if h.storyService == nil {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	sectionID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil || sectionID <= 0 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	reqCtx := c.Request.Context()
	section, err := h.storyService.GetSectionAdmin(reqCtx, uint(sectionID))
	if err != nil {
		h.logger.Error(reqCtx, "Failed to get section", err, map[string]interface{}{"section_id": sectionID})
		switch {
		case strings.Contains(err.Error(), "section not found"):
			HandleAppError(c, contextutils.ErrRecordNotFound)
		default:
			HandleAppError(c, contextutils.WrapError(err, "failed to get section"))
		}
		return
	}

	c.JSON(http.StatusOK, convertStorySectionWithQuestionsToAPI(section))
}
// DeleteStoryAdmin deletes the story identified by the positive integer "id"
// path parameter through the story service.
//
// Service errors are mapped by message: one containing "not found" becomes
// ErrRecordNotFound, one containing "cannot delete active story" becomes
// ErrConflict, and anything else is wrapped. A nil story service yields
// ErrInternalError and a malformed or non-positive id yields ErrInvalidFormat.
func (h *AdminHandler) DeleteStoryAdmin(c *gin.Context) {
	if h.storyService == nil {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	storyID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil || storyID <= 0 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	reqCtx := c.Request.Context()
	err := h.storyService.DeleteStoryAdmin(reqCtx, uint(storyID))
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"message": "Story deleted successfully"})
		return
	}

	h.logger.Error(reqCtx, "Failed to delete story (admin)", err, map[string]interface{}{"story_id": storyID})
	// The "not found" check must stay first so it wins over the conflict check.
	switch msg := err.Error(); {
	case strings.Contains(msg, "not found"):
		HandleAppError(c, contextutils.ErrRecordNotFound)
	case strings.Contains(msg, "cannot delete active story"):
		HandleAppError(c, contextutils.ErrConflict)
	default:
		HandleAppError(c, contextutils.WrapError(err, "failed to delete story"))
	}
}
// ClearUserData asks the user service to clear user data and responds with a
// success message stating that the users themselves are preserved. A service
// failure is logged and reported as a wrapped error.
func (h *AdminHandler) ClearUserData(c *gin.Context) {
	reqCtx := c.Request.Context()
	if err := h.userService.ClearUserData(reqCtx); err != nil {
		h.logger.Error(reqCtx, "Failed to clear user data", err, map[string]interface{}{})
		HandleAppError(c, contextutils.WrapError(err, "failed to clear user data"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "User data cleared successfully (users preserved)",
	})
}
// ClearDatabase asks the user service to reset the database and responds
// with a success message. A service failure is logged and reported as a
// wrapped error.
func (h *AdminHandler) ClearDatabase(c *gin.Context) {
	reqCtx := c.Request.Context()
	if err := h.userService.ResetDatabase(reqCtx); err != nil {
		h.logger.Error(reqCtx, "Failed to clear database", err, map[string]interface{}{})
		HandleAppError(c, contextutils.WrapError(err, "failed to clear database"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "Database cleared successfully",
	})
}
// GetQuestion responds with the question identified by the integer "id" path
// parameter.
//
// A non-numeric id yields ErrInvalidFormat. Any lookup failure is logged and
// reported as ErrQuestionNotFound, whatever the underlying cause.
func (h *AdminHandler) GetQuestion(c *gin.Context) {
	questionID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	reqCtx := c.Request.Context()
	question, lookupErr := h.questionService.GetQuestionByID(reqCtx, questionID)
	if lookupErr != nil {
		h.logger.Error(reqCtx, "Failed to get question", lookupErr, map[string]interface{}{"question_id": questionID})
		HandleAppError(c, contextutils.ErrQuestionNotFound)
		return
	}

	c.JSON(http.StatusOK, question)
}
// GetUsersForQuestion responds with the users the question service returns
// for the question identified by the integer "id" path parameter, together
// with the accompanying count under "total_count".
//
// A non-numeric id yields ErrInvalidFormat; a service failure is logged and
// reported as a wrapped error.
func (h *AdminHandler) GetUsersForQuestion(c *gin.Context) {
	questionID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	reqCtx := c.Request.Context()
	users, totalCount, err := h.questionService.GetUsersForQuestion(reqCtx, questionID)
	if err != nil {
		h.logger.Error(reqCtx, "Failed to get users for question", err, map[string]interface{}{"question_id": questionID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get users for question"))
		return
	}

	response := gin.H{
		"users":       users,
		"total_count": totalCount,
	}
	c.JSON(http.StatusOK, response)
}
// AssignUsersToQuestion assigns the users listed in the JSON body field
// "user_ids" to the question identified by the integer "id" path parameter.
//
// A non-numeric id yields ErrInvalidFormat. A body that fails to bind or
// carries an empty list yields ErrInvalidInput. The question is looked up
// first: sql.ErrNoRows becomes ErrQuestionNotFound and other lookup or
// assignment failures are logged and wrapped.
func (h *AdminHandler) AssignUsersToQuestion(c *gin.Context) {
	questionID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var payload struct {
		UserIDs []int `json:"user_ids" binding:"required"`
	}
	// A bind failure and an empty list are both reported as the same error.
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil || len(payload.UserIDs) == 0 {
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	reqCtx := c.Request.Context()

	// Confirm the question exists before touching assignments.
	if _, lookupErr := h.questionService.GetQuestionByID(reqCtx, questionID); lookupErr != nil {
		h.logger.Error(reqCtx, "Failed to get question", lookupErr, map[string]interface{}{"question_id": questionID})
		if errors.Is(lookupErr, sql.ErrNoRows) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(lookupErr, "failed to get question"))
		}
		return
	}

	if assignErr := h.questionService.AssignUsersToQuestion(reqCtx, questionID, payload.UserIDs); assignErr != nil {
		logFields := map[string]interface{}{
			"question_id": questionID,
			"user_ids":    payload.UserIDs,
		}
		h.logger.Error(reqCtx, "Failed to assign users to question", assignErr, logFields)
		HandleAppError(c, contextutils.WrapError(assignErr, "failed to assign users to question"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "Users assigned to question successfully"})
}
// UnassignUsersFromQuestion removes the users listed in the JSON body field
// "user_ids" from the question identified by the integer "id" path parameter.
//
// A non-numeric id yields ErrInvalidFormat. A body that fails to bind yields
// an invalid-input app error carrying the bind error as its cause, and an
// empty list yields ErrInvalidInput. The question is looked up first:
// sql.ErrNoRows becomes ErrQuestionNotFound and other lookup or removal
// failures are logged and wrapped.
func (h *AdminHandler) UnassignUsersFromQuestion(c *gin.Context) {
	questionID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var payload struct {
		UserIDs []int `json:"user_ids" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn, "Invalid request body", "", bindErr))
		return
	}
	if len(payload.UserIDs) == 0 {
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	reqCtx := c.Request.Context()

	// Confirm the question exists before touching assignments.
	if _, lookupErr := h.questionService.GetQuestionByID(reqCtx, questionID); lookupErr != nil {
		h.logger.Error(reqCtx, "Failed to get question", lookupErr, map[string]interface{}{"question_id": questionID})
		if errors.Is(lookupErr, sql.ErrNoRows) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(lookupErr, "failed to get question"))
		}
		return
	}

	if removeErr := h.questionService.UnassignUsersFromQuestion(reqCtx, questionID, payload.UserIDs); removeErr != nil {
		logFields := map[string]interface{}{
			"question_id": questionID,
			"user_ids":    payload.UserIDs,
		}
		h.logger.Error(reqCtx, "Failed to unassign users from question", removeErr, logFields)
		HandleAppError(c, contextutils.WrapError(removeErr, "failed to unassign users from question"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "Users unassigned from question successfully"})
}
// DeleteQuestion deletes the question identified by the integer "id" path
// parameter through the question service.
//
// A non-numeric id yields ErrInvalidFormat. A service error matching
// ErrRecordNotFound (per contextutils.IsError) is reported as
// ErrQuestionNotFound; any other failure is wrapped. All service failures
// are logged.
func (h *AdminHandler) DeleteQuestion(c *gin.Context) {
	questionID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	reqCtx := c.Request.Context()
	if delErr := h.questionService.DeleteQuestion(reqCtx, questionID); delErr != nil {
		h.logger.Error(reqCtx, "Failed to delete question", delErr, map[string]interface{}{"question_id": questionID})
		if contextutils.IsError(delErr, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(delErr, "failed to delete question"))
		}
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "Question deleted successfully"})
}
// GetQuestionsPaginated serves one page of questions for a given user.
//
// The "user_id" query parameter is mandatory (ErrMissingRequired when absent,
// ErrInvalidFormat when non-numeric). Pagination comes from ParsePagination
// and the "search", "type" and "status" filters from ParseFilters. The
// response holds the converted questions under "questions" and a
// "pagination" object with page, page_size, total and total_pages.
func (h *AdminHandler) GetQuestionsPaginated(c *gin.Context) {
	rawUserID := c.Query("user_id")
	if rawUserID == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	userID, convErr := strconv.Atoi(rawUserID)
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	page, pageSize := ParsePagination(c, 1, 10, 100)
	filters := ParseFilters(c, "search", "type", "status")

	reqCtx := c.Request.Context()
	questions, total, err := h.questionService.GetQuestionsPaginated(reqCtx, userID, page, pageSize, filters["search"], filters["type"], filters["status"])
	if err != nil {
		logFields := map[string]interface{}{
			"user_id": userID,
			"page":    page,
			"size":    pageSize,
		}
		h.logger.Error(reqCtx, "Failed to get paginated questions", err, logFields)
		HandleAppError(c, contextutils.WrapError(err, "failed to get questions"))
		return
	}

	// Convert every question to its API map form before building the response.
	items := make([]map[string]interface{}, 0, len(questions))
	for i := range questions {
		items = append(items, convertQuestionWithStatsToAPIMap(questions[i]))
	}

	totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
	c.JSON(http.StatusOK, gin.H{
		"questions": items,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
	})
}
// GetAllQuestions serves one page of questions across all users.
//
// Pagination comes from ParsePagination and the "search", "type", "status",
// "language" and "level" filters from ParseFilters. The optional "user_id"
// query parameter must be numeric when present (ErrInvalidFormat otherwise).
// The response holds the converted questions, a "pagination" object and the
// question statistics under "stats". If the statistics lookup fails, a
// warning is logged and an empty stats object is returned instead.
func (h *AdminHandler) GetAllQuestions(c *gin.Context) {
	page, pageSize := ParsePagination(c, 1, 20, 100)
	filters := ParseFilters(c, "search", "type", "status", "language", "level")
	search := filters["search"]

	// user_id is optional; a nil pointer means "no user restriction".
	var userID *int
	if raw := c.Query("user_id"); raw != "" {
		parsed, convErr := strconv.Atoi(raw)
		if convErr != nil {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
		userID = &parsed
	}

	reqCtx := c.Request.Context()
	questions, total, err := h.questionService.GetAllQuestionsPaginated(reqCtx, page, pageSize, search, filters["type"], filters["status"], filters["language"], filters["level"], userID)
	if err != nil {
		logFields := map[string]interface{}{
			"page":   page,
			"size":   pageSize,
			"search": search,
		}
		h.logger.Error(reqCtx, "Failed to get all questions", err, logFields)
		HandleAppError(c, contextutils.WrapError(err, "failed to get questions"))
		return
	}

	// Statistics are best-effort: a failure degrades to an empty object.
	stats, statsErr := h.questionService.GetQuestionStats(reqCtx)
	if statsErr != nil {
		h.logger.Warn(reqCtx, "Failed to get question stats", map[string]interface{}{"error": statsErr.Error()})
		stats = map[string]interface{}{}
	}

	items := make([]map[string]interface{}, 0, len(questions))
	for i := range questions {
		items = append(items, convertQuestionWithStatsToAPIMap(questions[i]))
	}

	totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
	c.JSON(http.StatusOK, gin.H{
		"questions": items,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
		"stats": stats,
	})
}
// GetReportedQuestionsPaginated serves one page of reported questions.
//
// Pagination comes from ParsePagination and the "search", "type", "language"
// and "level" filters from ParseFilters. The response holds the converted
// questions, a "pagination" object and the reported-question statistics
// under "stats". If the statistics lookup fails, a warning is logged and an
// empty stats object is returned instead.
func (h *AdminHandler) GetReportedQuestionsPaginated(c *gin.Context) {
	page, pageSize := ParsePagination(c, 1, 20, 100)
	filters := ParseFilters(c, "search", "type", "language", "level")
	search := filters["search"]

	reqCtx := c.Request.Context()
	questions, total, err := h.questionService.GetReportedQuestionsPaginated(reqCtx, page, pageSize, search, filters["type"], filters["language"], filters["level"])
	if err != nil {
		logFields := map[string]interface{}{
			"page":   page,
			"size":   pageSize,
			"search": search,
		}
		h.logger.Error(reqCtx, "Failed to get reported questions", err, logFields)
		HandleAppError(c, contextutils.WrapError(err, "failed to get reported questions"))
		return
	}

	// Statistics are best-effort: a failure degrades to an empty object.
	stats, statsErr := h.questionService.GetReportedQuestionsStats(reqCtx)
	if statsErr != nil {
		h.logger.Warn(reqCtx, "Failed to get reported questions stats", map[string]interface{}{"error": statsErr.Error()})
		stats = map[string]interface{}{}
	}

	items := make([]map[string]interface{}, 0, len(questions))
	for i := range questions {
		items = append(items, convertQuestionWithStatsToAPIMap(questions[i]))
	}

	totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
	c.JSON(http.StatusOK, gin.H{
		"questions": items,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
		"stats": stats,
	})
}
// ClearUserDataForUser clears the data of the user identified by the integer
// "id" path parameter via the user service and responds with a success
// message stating that the user is preserved.
//
// A non-numeric id yields ErrInvalidFormat. The user is looked up first: a
// lookup error is logged and wrapped, and a nil user yields
// ErrRecordNotFound. A failure while clearing is logged and wrapped.
func (h *AdminHandler) ClearUserDataForUser(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "clear_user_data_for_user")
	defer observability.FinishSpan(span, nil)

	userID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Verify the target user exists before clearing anything.
	user, lookupErr := h.userService.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		h.logger.Error(ctx, "Failed to get user for clear data operation", lookupErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(lookupErr, "failed to get user"))
		return
	case user == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	if clearErr := h.userService.ClearUserDataForUser(ctx, userID); clearErr != nil {
		h.logger.Error(ctx, "Failed to clear user data for user", clearErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(clearErr, "failed to clear user data for user"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "User data cleared successfully (user preserved)",
	})
}
// GetConfigz writes the handler's configuration value as indented JSON
// inside a "get_configz" trace span.
//
// NOTE(review): the whole config object is serialized as-is; confirm that it
// carries no secrets or that they are excluded from JSON encoding.
func (h *AdminHandler) GetConfigz(c *gin.Context) {
	_, traceSpan := observability.TraceHandlerFunction(c.Request.Context(), "get_configz")
	defer observability.FinishSpan(traceSpan, nil)

	c.IndentedJSON(http.StatusOK, h.config)
}
// GetRoles responds with the roles available in the system under the "roles"
// key.
//
// The list is hard-coded ("user" and "admin") because no role service backs
// this endpoint; nothing is read from the database. The timestamps are
// synthetic: both roles share one instant captured per request.
func (h *AdminHandler) GetRoles(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_roles")
	defer observability.FinishSpan(span, nil)

	// Capture the clock once so CreatedAt and UpdatedAt agree within each role
	// and across roles, rather than drifting by the time between four
	// separate time.Now() calls.
	now := time.Now()
	roles := []models.Role{
		{ID: 1, Name: "user", Description: "Normal site access", CreatedAt: now, UpdatedAt: now},
		{ID: 2, Name: "admin", Description: "Administrative access to all features", CreatedAt: now, UpdatedAt: now},
	}

	c.JSON(http.StatusOK, gin.H{"roles": roles})
}
// GetUserRoles responds with the roles of the user identified by the integer
// "id" path parameter under the "roles" key.
//
// A non-numeric id yields ErrInvalidFormat. The user is looked up first: a
// lookup error is logged and wrapped, and a nil user yields
// ErrRecordNotFound. A failure while fetching roles is logged and wrapped.
func (h *AdminHandler) GetUserRoles(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_roles")
	defer observability.FinishSpan(span, nil)

	userID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Verify the target user exists before querying roles.
	user, lookupErr := h.userService.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		h.logger.Error(ctx, "Failed to get user for roles operation", lookupErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(lookupErr, "failed to get user"))
		return
	case user == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	roles, rolesErr := h.userService.GetUserRoles(ctx, userID)
	if rolesErr != nil {
		h.logger.Error(ctx, "Failed to get user roles", rolesErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(rolesErr, "failed to get user roles"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"roles": roles})
}
// AssignRole assigns the role given by the JSON body field "role_id" to the
// user identified by the integer "id" path parameter.
//
// A non-numeric id yields ErrInvalidFormat. The user is looked up first: a
// lookup error is logged and wrapped, and a nil user yields
// ErrRecordNotFound. A body that fails to bind yields an invalid-input app
// error carrying the bind error as its cause. When the current user ID can be
// resolved, RequireSelfOrAdmin is enforced (ErrForbidden maps to
// contextutils.ErrForbidden, other failures are logged and wrapped); when it
// cannot be resolved the check is skipped.
func (h *AdminHandler) AssignRole(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "assign_role")
	defer observability.FinishSpan(span, nil)

	userID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Verify the target user exists before reading the body.
	user, lookupErr := h.userService.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		h.logger.Error(ctx, "Failed to get user for role assignment", lookupErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(lookupErr, "failed to get user"))
		return
	case user == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	var payload struct {
		RoleID int `json:"role_id" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn, "Invalid request body", "", bindErr))
		return
	}

	// The route is admin-only, but the self-or-admin rule is still enforced
	// explicitly whenever the requester can be identified.
	if currentUserID, idErr := GetCurrentUserID(c); idErr == nil {
		if authErr := RequireSelfOrAdmin(ctx, h.userService, currentUserID, userID); authErr != nil {
			if errors.Is(authErr, ErrForbidden) {
				HandleAppError(c, contextutils.ErrForbidden)
				return
			}
			h.logger.Error(ctx, "Failed to check authorization", authErr, map[string]interface{}{"user_id": currentUserID})
			HandleAppError(c, contextutils.WrapError(authErr, "failed to check authorization"))
			return
		}
	}

	if assignErr := h.userService.AssignRole(ctx, userID, payload.RoleID); assignErr != nil {
		h.logger.Error(ctx, "Failed to assign role to user", assignErr, map[string]interface{}{"user_id": userID, "role_id": payload.RoleID})
		HandleAppError(c, contextutils.WrapError(assignErr, "failed to assign role"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "Role assigned successfully"})
}
// RemoveRole removes the role given by the integer "roleId" path parameter
// from the user identified by the integer "id" path parameter.
//
// A non-numeric id or roleId yields ErrInvalidFormat. The user is looked up
// first: a lookup error is logged and wrapped, and a nil user yields
// ErrRecordNotFound. When the current user ID can be resolved,
// RequireSelfOrAdmin is enforced; when it cannot, the check is skipped. A
// removal error whose message contains "does not have role", or which matches
// ErrRecordNotFound, is reported as ErrRecordNotFound; anything else is
// wrapped.
func (h *AdminHandler) RemoveRole(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "remove_role")
	defer observability.FinishSpan(span, nil)

	userID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Verify the target user exists before parsing the role.
	user, lookupErr := h.userService.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		h.logger.Error(ctx, "Failed to get user for role removal", lookupErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(lookupErr, "failed to get user"))
		return
	case user == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	roleID, roleErr := strconv.Atoi(c.Param("roleId"))
	if roleErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// The route is admin-only, but the self-or-admin rule is still enforced
	// explicitly whenever the requester can be identified.
	if currentUserID, idErr := GetCurrentUserID(c); idErr == nil {
		if authErr := RequireSelfOrAdmin(ctx, h.userService, currentUserID, userID); authErr != nil {
			if errors.Is(authErr, ErrForbidden) {
				HandleAppError(c, contextutils.ErrForbidden)
				return
			}
			h.logger.Error(ctx, "Failed to check authorization", authErr, map[string]interface{}{"user_id": currentUserID})
			HandleAppError(c, contextutils.WrapError(authErr, "failed to check authorization"))
			return
		}
	}

	removeErr := h.userService.RemoveRole(ctx, userID, roleID)
	if removeErr == nil {
		c.JSON(http.StatusOK, gin.H{"message": "Role removed successfully"})
		return
	}

	h.logger.Error(ctx, "Failed to remove role", removeErr, map[string]interface{}{"user_id": userID, "role_id": roleID})
	// "User lacks the role" and "user/role not found" both surface as not-found.
	if strings.Contains(removeErr.Error(), "does not have role") || contextutils.IsError(removeErr, contextutils.ErrRecordNotFound) {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}
	HandleAppError(c, contextutils.WrapError(removeErr, "failed to remove role"))
}
// GetUsageStats responds with all usage statistics regrouped for the admin
// interface.
//
// The response contains:
//   - "usage_stats": service -> month ("2006-01") -> usage type -> an object
//     with characters_used, requests_made and the service's monthly quota;
//   - "monthly_totals": month -> service -> total_characters/total_requests,
//     accumulated only from rows whose usage type is "translation";
//   - "services": a fixed list containing "google";
//   - "cache_stats": request/character counters summed from the
//     "translation_cache_hit" and "translation_cache_miss" rows, plus the hit
//     rate as a percentage (0 when there were no cache requests).
//
// A nil usage-stats service yields ErrInternalError; a fetch failure is
// logged and wrapped.
func (h *AdminHandler) GetUsageStats(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_usage_stats")
	defer observability.FinishSpan(span, nil)

	if h.usageStatsSvc == nil {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	stats, err := h.usageStatsSvc.GetAllUsageStats(ctx)
	if err != nil {
		h.logger.Error(ctx, "Failed to get usage stats", err, map[string]interface{}{})
		HandleAppError(c, contextutils.WrapError(err, "failed to get usage stats"))
		return
	}

	serviceStats := make(map[string]map[string]map[string]interface{})
	monthlyTotals := make(map[string]map[string]map[string]int)
	var cacheHitRequests, cacheHitCharacters, cacheMissRequests int

	for _, entry := range stats {
		service, kind := entry.ServiceName, entry.UsageType
		month := entry.UsageMonth.Format("2006-01")

		// Lazily create the nested service -> month buckets.
		perMonth := serviceStats[service]
		if perMonth == nil {
			perMonth = make(map[string]map[string]interface{})
			serviceStats[service] = perMonth
		}
		perType := perMonth[month]
		if perType == nil {
			perType = make(map[string]interface{})
			perMonth[month] = perType
		}
		perType[kind] = map[string]interface{}{
			"characters_used": entry.CharactersUsed,
			"requests_made":   entry.RequestsMade,
			"quota":           h.usageStatsSvc.GetMonthlyQuota(service),
		}

		switch kind {
		case "translation_cache_hit":
			cacheHitRequests += entry.RequestsMade
			cacheHitCharacters += entry.CharactersUsed
		case "translation_cache_miss":
			cacheMissRequests += entry.RequestsMade
		case "translation":
			// Only real translations count towards the monthly totals.
			perService := monthlyTotals[month]
			if perService == nil {
				perService = make(map[string]map[string]int)
				monthlyTotals[month] = perService
			}
			totals := perService[service]
			if totals == nil {
				totals = map[string]int{
					"total_characters": 0,
					"total_requests":   0,
				}
				perService[service] = totals
			}
			totals["total_characters"] += entry.CharactersUsed
			totals["total_requests"] += entry.RequestsMade
		}
	}

	// Hit rate is a percentage of all cache lookups; it stays 0 with no lookups.
	var cacheHitRate float64
	if lookups := cacheHitRequests + cacheMissRequests; lookups > 0 {
		cacheHitRate = (float64(cacheHitRequests) / float64(lookups)) * 100
	}

	c.JSON(http.StatusOK, gin.H{
		"usage_stats":    serviceStats,
		"monthly_totals": monthlyTotals,
		"services":       []string{"google"}, // Currently only Google Translate
		"cache_stats": gin.H{
			"total_cache_hits_requests":   cacheHitRequests,
			"total_cache_hits_characters": cacheHitCharacters,
			"total_cache_misses_requests": cacheMissRequests,
			"cache_hit_rate":              cacheHitRate,
		},
	})
}
// GetUsageStatsByService responds with the usage statistics of the service
// named by the "service" path parameter.
//
// The service name must be non-empty, translation must be enabled in the
// configuration, and the name must be a key of the configured translation
// providers; otherwise ErrInvalidFormat is reported. A missing configuration
// or a nil usage-stats service yields ErrInternalError. A fetch failure is
// logged and wrapped.
//
// The response carries the service name under "service" and, under "data",
// one object per stat row with month ("2006-01"), usage_type,
// characters_used, requests_made and quota. The quota is only filled in for
// rows whose usage type is "translation"; it is null for all other rows.
func (h *AdminHandler) GetUsageStatsByService(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_usage_stats_by_service")
	defer observability.FinishSpan(span, nil)

	serviceName := c.Param("service")
	if serviceName == "" {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Without configuration the service name cannot be validated; report an
	// internal error instead of dereferencing a nil config.
	if h.config == nil {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	// Validate the service name against the configured translation providers.
	if !h.config.Translation.Enabled {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	if _, known := h.config.Translation.Providers[serviceName]; !known {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	if h.usageStatsSvc == nil {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	stats, err := h.usageStatsSvc.GetUsageStatsByService(ctx, serviceName)
	if err != nil {
		h.logger.Error(ctx, "Failed to get usage stats by service", err, map[string]interface{}{"service": serviceName})
		HandleAppError(c, contextutils.WrapError(err, "failed to get usage stats"))
		return
	}

	// Start from a non-nil slice so an empty result encodes as [] rather than null.
	monthlyData := make([]map[string]interface{}, 0, len(stats))
	for _, stat := range stats {
		// Quota applies to real translation usage only, not cache hits/misses.
		var quota interface{}
		if stat.UsageType == "translation" {
			quota = h.usageStatsSvc.GetMonthlyQuota(serviceName)
		}
		monthlyData = append(monthlyData, map[string]interface{}{
			"month":           stat.UsageMonth.Format("2006-01"),
			"usage_type":      stat.UsageType,
			"characters_used": stat.CharactersUsed,
			"requests_made":   stat.RequestsMade,
			"quota":           quota,
		})
	}

	c.JSON(http.StatusOK, gin.H{
		"service": serviceName,
		"data":    monthlyData,
	})
}
// calculateUserAggregateStats summarizes the given users as a map of counters.
//
// The returned map contains:
//   - total_users: number of users passed in;
//   - by_language, by_level: tallies keyed by value ("unknown" when unset);
//   - by_ai_provider, by_ai_model: tallies keyed by value ("none" when unset);
//   - ai_enabled / ai_disabled: users with and without AI switched on;
//   - active_users / inactive_users: active means a last-active timestamp
//     within the past seven days; users without a timestamp are inactive;
//   - total_questions_answered, total_correct_answers: sums over each user's
//     progress as reported by learningService;
//   - average_accuracy: correct/answered as a percentage, 0.0 when nothing
//     has been answered.
//
// A user whose progress lookup fails is logged as a warning and contributes
// nothing to the answer totals; the user is still counted everywhere else.
func calculateUserAggregateStats(ctx context.Context, users []models.User, learningService services.LearningServiceInterface, logger *observability.Logger) map[string]interface{} {
	var (
		byLanguage = make(map[string]int)
		byLevel    = make(map[string]int)
		byProvider = make(map[string]int)
		byModel    = make(map[string]int)

		aiOn, aiOff       int
		active, inactive  int
		answered, correct int
	)

	cutoff := time.Now().AddDate(0, 0, -7)

	for _, u := range users {
		lang, level, provider, model := "unknown", "unknown", "none", "none"
		if u.PreferredLanguage.Valid {
			lang = u.PreferredLanguage.String
		}
		if u.CurrentLevel.Valid {
			level = u.CurrentLevel.String
		}
		if u.AIProvider.Valid {
			provider = u.AIProvider.String
		}
		if u.AIModel.Valid {
			model = u.AIModel.String
		}
		byLanguage[lang]++
		byLevel[level]++
		byProvider[provider]++
		byModel[model]++

		if u.AIEnabled.Valid && u.AIEnabled.Bool {
			aiOn++
		} else {
			aiOff++
		}

		if u.LastActive.Valid && u.LastActive.Time.After(cutoff) {
			active++
		} else {
			inactive++
		}

		progress, err := learningService.GetUserProgress(ctx, u.ID)
		switch {
		case err != nil:
			logger.Warn(ctx, "Failed to get progress for user", map[string]interface{}{"user_id": u.ID, "error": err.Error()})
		case progress != nil:
			answered += progress.TotalQuestions
			correct += progress.CorrectAnswers
		}
	}

	accuracy := 0.0
	if answered > 0 {
		accuracy = float64(correct) / float64(answered) * 100.0
	}

	return map[string]interface{}{
		"total_users":              len(users),
		"by_language":              byLanguage,
		"by_level":                 byLevel,
		"by_ai_provider":           byProvider,
		"by_ai_model":              byModel,
		"ai_enabled":               aiOn,
		"ai_disabled":              aiOff,
		"active_users":             active,
		"inactive_users":           inactive,
		"total_questions_answered": answered,
		"total_correct_answers":    correct,
		"average_accuracy":         accuracy,
	}
}
package handlers
import (
"encoding/json"
"fmt"
"strings"
)
// MergeAISuggestion merges an AI response into a copy of the original question
// map and returns the copy; the caller's map is not modified.
//
// Top-level metadata of original is preserved. Entries of aiResp["content"]
// (when it is a map) overwrite same-named entries of the copy's "content" map,
// which is created if missing. The top level of the result is the canonical
// location for `correct_answer` and `explanation`: values from aiResp replace
// the original's, and any occurrences under `content` are removed.
// `change_reason` is copied from aiResp when present. The content map is
// finally passed through NormalizeContent.
//
// Because the copy is made through a JSON round trip, numeric values taken
// from original come back as float64. If original cannot be marshalled the
// merge starts from an empty map. A nil original is treated as an empty one.
func MergeAISuggestion(original, aiResp map[string]interface{}) map[string]interface{} {
	// Deep-copy original via JSON so nested maps are not shared with the caller.
	out := map[string]interface{}{}
	if encoded, err := json.Marshal(original); err == nil {
		// Best effort: on a decode error we continue with whatever was decoded.
		_ = json.Unmarshal(encoded, &out)
	}
	// A nil original encodes as JSON null, and decoding null resets out to a
	// nil map; writing to it below would panic, so restore an empty map.
	if out == nil {
		out = map[string]interface{}{}
	}

	// Ensure the content map exists.
	contentMap, _ := out["content"].(map[string]interface{})
	if contentMap == nil {
		contentMap = map[string]interface{}{}
		out["content"] = contentMap
	}

	// Merge AI-provided content into the content map.
	if aiContent, ok := aiResp["content"].(map[string]interface{}); ok {
		for k, v := range aiContent {
			contentMap[k] = v
		}
	}

	// Answer/explanation live at the top level of the output; prefer the AI
	// response's values when present.
	if ca, ok := aiResp["correct_answer"]; ok {
		out["correct_answer"] = ca
	}
	if ex, ok := aiResp["explanation"]; ok {
		out["explanation"] = ex
	}

	// Remove any duplicates that may exist inside content.
	delete(contentMap, "correct_answer")
	delete(contentMap, "explanation")

	if cr, ok := aiResp["change_reason"]; ok {
		out["change_reason"] = cr
	}

	NormalizeContent(contentMap)
	return out
}
// NormalizeContent sanitizes a question content map in place.
//
// "options" is coerced to []string:
//   - []interface{}: non-string items are skipped; the rest are trimmed,
//     empty values dropped and duplicates removed (first occurrence wins);
//   - []string: left untouched;
//   - string: decoded as a JSON string array when possible, otherwise split
//     on newlines and commas and cleaned like the []interface{} case;
//   - anything else (including nil): the key is removed.
//
// Any "correct_answer" key is removed, since the answer belongs at the top
// level of the question, not under content.
//
// The text fields "explanation", "question", "passage" and "sentence" are
// coerced to strings with fmt.Sprint when they hold another type. A nil value
// is removed instead of being rendered as the literal text "<nil>".
func NormalizeContent(contentMap map[string]interface{}) {
	// cleanOptions trims every candidate and keeps the first occurrence of
	// each non-empty value, preserving order. It returns nil when nothing is kept.
	cleanOptions := func(candidates []string) []string {
		var kept []string
		seen := make(map[string]bool, len(candidates))
		for _, candidate := range candidates {
			candidate = strings.TrimSpace(candidate)
			if candidate == "" || seen[candidate] {
				continue
			}
			seen[candidate] = true
			kept = append(kept, candidate)
		}
		return kept
	}

	// Every branch below leaves "options" as a []string or removes it, so no
	// second conversion pass is needed afterwards.
	if raw, present := contentMap["options"]; present {
		switch opts := raw.(type) {
		case []string:
			// Already the target type.
		case []interface{}:
			texts := make([]string, 0, len(opts))
			for _, item := range opts {
				if s, ok := item.(string); ok {
					texts = append(texts, s)
				}
			}
			contentMap["options"] = cleanOptions(texts)
		case string:
			var parsed []string
			if err := json.Unmarshal([]byte(opts), &parsed); err == nil {
				contentMap["options"] = parsed
			} else {
				// Not a JSON array: treat as a newline/comma separated list.
				parts := strings.FieldsFunc(opts, func(r rune) bool { return r == '\n' || r == ',' })
				contentMap["options"] = cleanOptions(parts)
			}
		default:
			delete(contentMap, "options")
		}
	}

	// Ensure no stray correct_answer under content.
	delete(contentMap, "correct_answer")

	// Ensure simple text fields are strings.
	for _, key := range []string{"explanation", "question", "passage", "sentence"} {
		switch v := contentMap[key].(type) {
		case nil:
			// Absent or explicit null: drop the key (a no-op when absent)
			// rather than storing fmt.Sprint(nil) == "<nil>".
			delete(contentMap, key)
		case string:
			// Already a string.
		default:
			contentMap[key] = fmt.Sprint(v)
		}
	}
}
package handlers
import (
"net/http"
"strconv"
"strings"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
)
// AIConversationHandler serves the HTTP endpoints for AI conversations,
// their messages, and message bookmarks.
type AIConversationHandler struct {
	// conversationService performs every conversation, message and bookmark
	// operation on behalf of the handlers.
	conversationService services.ConversationServiceInterface
	// cfg is the application configuration supplied at construction.
	cfg *config.Config
	// logger records failures reported by the conversation service.
	logger *observability.Logger
}
// NewAIConversationHandler builds an AIConversationHandler around the given
// conversation service, configuration and logger. The arguments are stored
// as-is; nothing is validated here.
func NewAIConversationHandler(
	conversationService services.ConversationServiceInterface,
	cfg *config.Config,
	logger *observability.Logger,
) *AIConversationHandler {
	handler := new(AIConversationHandler)
	handler.conversationService = conversationService
	handler.cfg = cfg
	handler.logger = logger
	return handler
}
// GetConversations handles GET /v1/ai/conversations.
//
// It lists the session user's conversations, paginated by the "limit"
// (default 20, allowed 1..100) and "offset" (default 0, non-negative) query
// parameters. Each conversation in the response carries an extra
// "message_count" field; if the counts cannot be loaded the failure is logged
// and every count is reported as 0.
func (h *AIConversationHandler) GetConversations(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_ai_conversations")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Pagination parameters.
	limit, limitErr := strconv.Atoi(c.DefaultQuery("limit", "20"))
	if limitErr != nil || limit < 1 || limit > 100 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	offset, offsetErr := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if offsetErr != nil || offset < 0 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
		attribute.Int("offset", offset),
	)

	conversations, total, err := h.conversationService.GetUserConversations(ctx, uint(userID), limit, offset)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user conversations", err, map[string]interface{}{
			"user_id": userID,
			"limit":   limit,
			"offset":  offset,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get conversations"))
		return
	}

	// Message counts let the UI show badges without loading messages. A
	// failure here is not fatal: fall back to an empty lookup table.
	messageCounts, err := h.conversationService.GetUserMessageCounts(ctx, uint(userID))
	if err != nil {
		h.logger.Error(ctx, "Failed to get message counts", err, map[string]interface{}{
			"user_id": userID,
		})
		messageCounts = map[string]int{}
	}

	// countedConversation adds message_count next to the embedded
	// conversation's own JSON fields.
	type countedConversation struct {
		api.Conversation
		MessageCount int `json:"message_count"`
	}
	items := make([]countedConversation, 0, len(conversations))
	for _, conv := range conversations {
		items = append(items, countedConversation{
			Conversation: conv,
			MessageCount: messageCounts[conv.Id.String()],
		})
	}

	c.JSON(http.StatusOK, gin.H{
		"conversations": items,
		"total":         total,
		"limit":         limit,
		"offset":        offset,
	})
}
// CreateConversation handles POST /v1/ai/conversations.
//
// It binds the JSON body to a CreateConversationRequest, creates the
// conversation for the session user and answers 201 with the created
// conversation. An unreadable body is reported as an invalid-input error.
func (h *AIConversationHandler) CreateConversation(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "create_ai_conversation")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var body api.CreateConversationRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("conversation_title", body.Title),
	)

	created, err := h.conversationService.CreateConversation(ctx, uint(userID), &body)
	if err != nil {
		h.logger.Error(ctx, "Failed to create conversation", err, map[string]interface{}{
			"user_id": userID,
			"title":   body.Title,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to create conversation"))
		return
	}

	c.JSON(http.StatusCreated, created)
}
// GetConversation handles GET /v1/ai/conversations/{id}.
//
// The "id" path parameter must be a non-empty, well-formed UUID. The
// conversation is loaded for the session user; a service error whose text
// contains "conversation not found" is answered as ErrRecordNotFound, any
// other error is wrapped and reported as a failure.
func (h *AIConversationHandler) GetConversation(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_ai_conversation")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	id := c.Param("id")
	if id == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	if _, parseErr := uuid.Parse(id); parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("conversation_id", id),
	)

	found, err := h.conversationService.GetConversation(ctx, id, uint(userID))
	if err != nil {
		h.logger.Error(ctx, "Failed to get conversation", err, map[string]interface{}{
			"user_id":         userID,
			"conversation_id": id,
		})
		// The service signals a missing conversation through its error text.
		if strings.Contains(err.Error(), "conversation not found") {
			HandleAppError(c, contextutils.ErrRecordNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(err, "failed to get conversation"))
		}
		return
	}

	c.JSON(http.StatusOK, found)
}
// UpdateConversation handles PUT /v1/ai/conversations/{id}.
//
// The "id" path parameter must be a non-empty, well-formed UUID and the body
// must bind to an UpdateConversationRequest. The updated conversation is
// returned with 200. A service error whose text contains
// "conversation not found" is answered as ErrRecordNotFound.
func (h *AIConversationHandler) UpdateConversation(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "update_ai_conversation")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	id := c.Param("id")
	if id == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	if _, parseErr := uuid.Parse(id); parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var body api.UpdateConversationRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("conversation_id", id),
		attribute.String("new_title", body.Title),
	)

	updated, err := h.conversationService.UpdateConversation(ctx, id, uint(userID), &body)
	if err != nil {
		h.logger.Error(ctx, "Failed to update conversation", err, map[string]interface{}{
			"user_id":         userID,
			"conversation_id": id,
			"new_title":       body.Title,
		})
		// The service signals a missing conversation through its error text.
		if strings.Contains(err.Error(), "conversation not found") {
			HandleAppError(c, contextutils.ErrRecordNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(err, "failed to update conversation"))
		}
		return
	}

	c.JSON(http.StatusOK, updated)
}
// DeleteConversation handles DELETE /v1/ai/conversations/{id}.
//
// The "id" path parameter must be a non-empty, well-formed UUID. On success
// the response is 204 with no body. A service error whose text contains
// "conversation not found" is answered as ErrRecordNotFound.
func (h *AIConversationHandler) DeleteConversation(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "delete_ai_conversation")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	id := c.Param("id")
	if id == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	if _, parseErr := uuid.Parse(id); parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("conversation_id", id),
	)

	// Removes the conversation together with its messages.
	if err := h.conversationService.DeleteConversation(ctx, id, uint(userID)); err != nil {
		h.logger.Error(ctx, "Failed to delete conversation", err, map[string]interface{}{
			"user_id":         userID,
			"conversation_id": id,
		})
		// The service signals a missing conversation through its error text.
		if strings.Contains(err.Error(), "conversation not found") {
			HandleAppError(c, contextutils.ErrRecordNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(err, "failed to delete conversation"))
		}
		return
	}

	c.Status(http.StatusNoContent)
}
// AddMessage handles POST /v1/ai/conversations/{conversationId}/messages.
//
// The "conversationId" path parameter must be a non-empty, well-formed UUID
// and the body must bind to a CreateMessageRequest. The created message is
// returned with 201. A service error whose text contains
// "conversation not found" is answered as ErrRecordNotFound.
func (h *AIConversationHandler) AddMessage(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "add_ai_message")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	id := c.Param("conversationId")
	if id == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	if _, parseErr := uuid.Parse(id); parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var body api.CreateMessageRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	// Only the size of the text is traced, never the text itself.
	textLen := 0
	if text := body.Content.Text; text != nil {
		textLen = len(*text)
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("conversation_id", id),
		attribute.String("message_role", string(body.Role)),
		attribute.Int("message_content_length", textLen),
	)

	message, err := h.conversationService.AddMessage(ctx, id, uint(userID), &body)
	if err != nil {
		h.logger.Error(ctx, "Failed to add message to conversation", err, map[string]interface{}{
			"user_id":         userID,
			"conversation_id": id,
			"message_role":    body.Role,
		})
		// The service signals a missing conversation through its error text.
		if strings.Contains(err.Error(), "conversation not found") {
			HandleAppError(c, contextutils.ErrRecordNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(err, "failed to add message"))
		}
		return
	}

	c.JSON(http.StatusCreated, message)
}
// SearchConversations handles GET /v1/ai/search.
//
// The "q" query parameter is required; "limit" (default 20, allowed 1..100)
// and "offset" (default 0, non-negative) paginate the result. The response
// echoes the query and pagination values next to the matches and the total.
func (h *AIConversationHandler) SearchConversations(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "search_ai_conversations")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	term := c.Query("q")
	if term == "" {
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	limit, limitErr := strconv.Atoi(c.DefaultQuery("limit", "20"))
	if limitErr != nil || limit < 1 || limit > 100 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	offset, offsetErr := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if offsetErr != nil || offset < 0 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("search_query", term),
		attribute.Int("limit", limit),
		attribute.Int("offset", offset),
	)

	matches, total, err := h.conversationService.SearchConversations(ctx, uint(userID), term, limit, offset)
	if err != nil {
		h.logger.Error(ctx, "Failed to search conversations", err, map[string]interface{}{
			"user_id": userID,
			"query":   term,
			"limit":   limit,
			"offset":  offset,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to search conversations"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"conversations": matches,
		"query":         term,
		"total":         total,
		"limit":         limit,
		"offset":        offset,
	})
}
// ToggleMessageBookmark handles PUT /v1/ai/conversations/bookmark.
//
// The JSON body must provide "conversation_id" and "message_id", both
// well-formed UUIDs. The response is {"bookmarked": <new state>}. A service
// error whose text contains "not found" is answered as ErrRecordNotFound.
func (h *AIConversationHandler) ToggleMessageBookmark(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "toggle_message_bookmark")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	type bookmarkRequest struct {
		ConversationID string `json:"conversation_id" binding:"required"`
		MessageID      string `json:"message_id" binding:"required"`
	}
	var body bookmarkRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	// Both identifiers must be UUIDs; the conversation ID is checked first.
	for _, candidate := range []string{body.ConversationID, body.MessageID} {
		if _, parseErr := uuid.Parse(candidate); parseErr != nil {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("conversation_id", body.ConversationID),
		attribute.String("message_id", body.MessageID),
	)

	bookmarked, err := h.conversationService.ToggleMessageBookmark(ctx, body.ConversationID, body.MessageID, uint(userID))
	if err != nil {
		h.logger.Error(ctx, "Failed to toggle message bookmark", err, map[string]interface{}{
			"user_id":         userID,
			"conversation_id": body.ConversationID,
			"message_id":      body.MessageID,
		})
		// Covers both a missing conversation and a missing message.
		if strings.Contains(err.Error(), "not found") {
			HandleAppError(c, contextutils.ErrRecordNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(err, "failed to toggle message bookmark"))
		}
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"bookmarked": bookmarked,
	})
}
// GetBookmarkedMessages handles GET /v1/ai/bookmarks.
//
// It lists the session user's bookmarked messages. The optional "q" query
// parameter is passed to the service as a search term; "limit" (default 20,
// allowed 1..100) and "offset" (default 0, non-negative) paginate the result.
func (h *AIConversationHandler) GetBookmarkedMessages(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_bookmarked_messages")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	term := c.DefaultQuery("q", "")

	limit, limitErr := strconv.Atoi(c.DefaultQuery("limit", "20"))
	if limitErr != nil || limit < 1 || limit > 100 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	offset, offsetErr := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if offsetErr != nil || offset < 0 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("search_query", term),
		attribute.Int("limit", limit),
		attribute.Int("offset", offset),
	)

	bookmarks, total, err := h.conversationService.GetBookmarkedMessages(ctx, uint(userID), term, limit, offset)
	if err != nil {
		h.logger.Error(ctx, "Failed to get bookmarked messages", err, map[string]interface{}{
			"user_id": userID,
			"query":   term,
			"limit":   limit,
			"offset":  offset,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get bookmarked messages"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"messages": bookmarks,
		"query":    term,
		"total":    total,
		"limit":    limit,
		"offset":   offset,
	})
}
package handlers
import (
"net/http"
"strconv"
"quizapp/internal/api"
"quizapp/internal/middleware"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// AuthAPIKeyHandler serves the HTTP endpoints that create, list, delete and
// test authentication API keys.
type AuthAPIKeyHandler struct {
	// apiKeyService performs the API key operations for the handlers.
	apiKeyService services.AuthAPIKeyServiceInterface
	// logger records failures reported by the API key service.
	logger *observability.Logger
}
// NewAuthAPIKeyHandler builds an AuthAPIKeyHandler around the given API key
// service and logger. The arguments are stored as-is.
func NewAuthAPIKeyHandler(apiKeyService services.AuthAPIKeyServiceInterface, logger *observability.Logger) *AuthAPIKeyHandler {
	handler := new(AuthAPIKeyHandler)
	handler.apiKeyService = apiKeyService
	handler.logger = logger
	return handler
}
// CreateAPIKey handles POST /v1/api-keys.
//
// The user ID is read from the gin context; a missing value is answered as
// unauthorized and a non-int value as an internal error. The JSON body must
// provide "key_name" and "permission_level". On success the response is 201
// and includes the raw key, which is only ever returned by this call.
func (h *AuthAPIKeyHandler) CreateAPIKey(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "CreateAPIKey")
	defer observability.FinishSpan(span, nil)

	// The user ID is expected in the context under middleware.UserIDKey.
	rawUserID, present := c.Get(middleware.UserIDKey)
	if !present {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	userID, isInt := rawUserID.(int)
	if !isInt {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}
	span.SetAttributes(attribute.Int("user_id", userID))

	type createKeyRequest struct {
		KeyName         string `json:"key_name" binding:"required"`
		PermissionLevel string `json:"permission_level" binding:"required"`
	}
	var body createKeyRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}
	span.SetAttributes(
		attribute.String("key_name", body.KeyName),
		attribute.String("permission_level", body.PermissionLevel),
	)

	created, rawKey, err := h.apiKeyService.CreateAPIKey(ctx, userID, body.KeyName, body.PermissionLevel)
	if err != nil {
		h.logger.Error(ctx, "Failed to create API key", err, map[string]interface{}{
			"user_id":          userID,
			"key_name":         body.KeyName,
			"permission_level": body.PermissionLevel,
		})
		HandleAppError(c, err)
		return
	}
	span.SetAttributes(attribute.Int("api_key_id", created.ID))

	// This is the only response that ever contains the full key.
	c.JSON(http.StatusCreated, gin.H{
		"id":               created.ID,
		"key_name":         created.KeyName,
		"key":              rawKey,
		"key_prefix":       created.KeyPrefix,
		"permission_level": created.PermissionLevel,
		"created_at":       created.CreatedAt,
		"message":          "Save this API key now. You won't be able to see it again!",
	})
}
// ListAPIKeys handles GET /v1/api-keys.
//
// It returns the API keys of the user identified in the gin context as an
// APIKeysListResponse holding the converted key summaries and their count.
func (h *AuthAPIKeyHandler) ListAPIKeys(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "ListAPIKeys")
	defer observability.FinishSpan(span, nil)

	// The user ID is expected in the context under middleware.UserIDKey.
	rawUserID, present := c.Get(middleware.UserIDKey)
	if !present {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	userID, isInt := rawUserID.(int)
	if !isInt {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}
	span.SetAttributes(attribute.Int("user_id", userID))

	keys, err := h.apiKeyService.ListAPIKeys(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to list API keys", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, err)
		return
	}
	span.SetAttributes(attribute.Int("count", len(keys)))

	// Respond with the generated API types so serialization follows the schema.
	summaries := convertAuthAPIKeysToAPI(keys)
	total := len(summaries)
	c.JSON(http.StatusOK, api.APIKeysListResponse{
		ApiKeys: &summaries,
		Count:   &total,
	})
}
// DeleteAPIKey handles DELETE /v1/api-keys/:id.
//
// The ":id" path parameter must parse as an integer; otherwise the request is
// rejected as invalid input. The key is deleted on behalf of the user
// identified in the gin context, and a success message is returned with 200.
func (h *AuthAPIKeyHandler) DeleteAPIKey(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "DeleteAPIKey")
	defer observability.FinishSpan(span, nil)

	// The user ID is expected in the context under middleware.UserIDKey.
	rawUserID, present := c.Get(middleware.UserIDKey)
	if !present {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	userID, isInt := rawUserID.(int)
	if !isInt {
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	keyID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid API key ID",
			"",
			parseErr,
		)
		HandleAppError(c, invalid)
		return
	}

	span.SetAttributes(
		attribute.Int("user_id", userID),
		attribute.Int("key_id", keyID),
	)

	if err := h.apiKeyService.DeleteAPIKey(ctx, userID, keyID); err != nil {
		h.logger.Error(ctx, "Failed to delete API key", err, map[string]interface{}{
			"user_id": userID,
			"key_id":  keyID,
		})
		HandleAppError(c, err)
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "API key deleted successfully",
	})
}
// TestRead handles GET /v1/api-keys/test-read.
//
// It echoes the user ID, username and API key ID found in the gin context,
// together with the request method, so a caller can verify API key access.
// The key's permission level is looked up when both IDs are non-zero; if the
// lookup fails or finds nothing, the permission level is reported as "".
func (h *AuthAPIKeyHandler) TestRead(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "TestAPIKeyRead")
	defer observability.FinishSpan(span, nil)

	// Values are read from the gin context; absent ones come back as zero values.
	uid := c.GetInt(middleware.UserIDKey)
	name := c.GetString(middleware.UsernameKey)
	keyID := c.GetInt(middleware.APIKeyIDKey)

	level := ""
	if keyID != 0 && uid != 0 {
		key, err := h.apiKeyService.GetAPIKeyByID(ctx, uid, keyID)
		if err == nil && key != nil {
			level = key.PermissionLevel
		}
	}

	c.JSON(http.StatusOK, gin.H{
		"ok":               true,
		"user_id":          uid,
		"username":         name,
		"permission_level": level,
		"api_key_id":       keyID,
		"method":           c.Request.Method,
	})
}
// TestWrite handles POST /v1/api-keys/test-write.
//
// It echoes the user ID, username and API key ID found in the gin context,
// together with the request method. The handler performs no permission check
// of its own. The key's permission level is looked up when both IDs are
// non-zero; if the lookup fails or finds nothing, it is reported as "".
func (h *AuthAPIKeyHandler) TestWrite(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "TestAPIKeyWrite")
	defer observability.FinishSpan(span, nil)

	// Values are read from the gin context; absent ones come back as zero values.
	uid := c.GetInt(middleware.UserIDKey)
	name := c.GetString(middleware.UsernameKey)
	keyID := c.GetInt(middleware.APIKeyIDKey)

	level := ""
	if keyID != 0 && uid != 0 {
		key, err := h.apiKeyService.GetAPIKeyByID(ctx, uid, keyID)
		if err == nil && key != nil {
			level = key.PermissionLevel
		}
	}

	c.JSON(http.StatusOK, gin.H{
		"ok":               true,
		"user_id":          uid,
		"username":         name,
		"permission_level": level,
		"api_key_id":       keyID,
		"method":           c.Request.Method,
	})
}
package handlers
import (
"crypto/rand"
"errors"
"net/http"
"regexp"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/middleware"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
openapi_types "github.com/oapi-codegen/runtime/types"
"go.opentelemetry.io/otel/attribute"
)
// AuthHandler serves the authentication HTTP endpoints (login, logout and
// authentication status).
type AuthHandler struct {
	// userService authenticates users and loads/updates their records.
	userService services.UserServiceInterface
	// oauthService is the OAuth service supplied at construction.
	oauthService *services.OAuthService
	// config is the application configuration supplied at construction.
	config *config.Config
	// logger records authentication and session failures.
	logger *observability.Logger
}
// NewAuthHandler builds an AuthHandler around the given user service, OAuth
// service, configuration and logger. The arguments are stored as-is.
func NewAuthHandler(userService services.UserServiceInterface, oauthService *services.OAuthService, cfg *config.Config, logger *observability.Logger) *AuthHandler {
	handler := new(AuthHandler)
	handler.userService = userService
	handler.oauthService = oauthService
	handler.config = cfg
	handler.logger = logger
	return handler
}
// Login handles user login requests.
//
// It binds the JSON body to a LoginRequest and authenticates the credentials
// through the user service. A failed authentication, or a nil user, is
// answered with ErrInvalidCredentials. On success the user's ID and username
// are stored in the session and the user is returned in a LoginResponse.
// Failing to update the last-active timestamp is logged but does not fail
// the login; failing to save the session does.
//
// The traced context returned by TraceHandlerFunction is passed to every
// downstream call, as the other handlers in this package do, instead of the
// bare request context.
func (h *AuthHandler) Login(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "login")
	defer observability.FinishSpan(span, nil)

	var req api.LoginRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			err,
		))
		return
	}

	// Only record whether a password was supplied, never the password itself.
	span.SetAttributes(
		attribute.String("auth.username", req.Username),
		attribute.Bool("auth.password_provided", req.Password != ""),
	)

	// Authenticate user against database.
	user, err := h.userService.AuthenticateUser(ctx, req.Username, req.Password)
	if err != nil {
		h.logger.Error(ctx, "Authentication failed for user", err, map[string]interface{}{"username": req.Username})
		HandleAppError(c, contextutils.ErrInvalidCredentials)
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrInvalidCredentials)
		return
	}

	// Update span attributes with user info.
	span.SetAttributes(
		attribute.Int("user.id", user.ID),
		attribute.String("user.username", user.Username),
		attribute.Bool("user.email_provided", user.Email.Valid),
		attribute.String("user.language", user.PreferredLanguage.String),
		attribute.String("user.level", user.CurrentLevel.String),
	)

	// Best effort: a failed last-active update must not fail the login.
	if err := h.userService.UpdateLastActive(ctx, user.ID); err != nil {
		h.logger.Warn(ctx, "Failed to update last active for user", map[string]interface{}{"user_id": user.ID, "error": err.Error()})
	}

	// Create session.
	session := sessions.Default(c)
	session.Set(middleware.UserIDKey, user.ID)
	session.Set(middleware.UsernameKey, user.Username)
	if err := session.Save(); err != nil {
		h.logger.Error(ctx, "Failed to save session", err, map[string]interface{}{"error": err.Error()})
		HandleAppError(c, contextutils.WrapError(err, "failed to create session"))
		return
	}

	// Convert models.User to api.User with proper API key checking.
	apiUser := convertUserToAPIWithService(ctx, user, h.userService)

	// Return user info (without API key).
	c.JSON(http.StatusOK, api.LoginResponse{
		Success: boolPtr(true),
		Message: stringPtr("Login successful"),
		User:    &apiUser,
	})
}
// Logout handles user logout requests.
//
// It clears the session and saves it; a save failure is reported as an error,
// otherwise a success response is returned. Before clearing, the session's
// user ID and username are attached to the trace span when they are present
// with the expected types (int and string).
func (h *AuthHandler) Logout(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "logout")
	defer observability.FinishSpan(span, nil)

	session := sessions.Default(c)

	// Read user info before clearing the session, for tracing only. The
	// comma-ok assertions skip missing values and also values of an unexpected
	// type, so a malformed session cannot panic the logout request.
	if id, ok := session.Get(middleware.UserIDKey).(int); ok {
		span.SetAttributes(attribute.Int("user.id", id))
	}
	if name, ok := session.Get(middleware.UsernameKey).(string); ok {
		span.SetAttributes(attribute.String("user.username", name))
	}

	session.Clear()
	if err := session.Save(); err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to clear session"))
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{
		Success: true,
		Message: stringPtr("Logout successful"),
	})
}
// Status returns the current authentication status.
//
// Responses (all HTTP 200 unless an internal error occurs):
//   - no user ID in the session: {"authenticated": false, "user": nil}
//   - user ID in the session is not an int: the session is cleared and the
//     unauthenticated payload is returned
//   - user ID present but the user no longer exists: the session is cleared
//     and the unauthenticated payload is returned
//   - user found: {"authenticated": true, "user": <api user>}
//
// A failure to load the user is reported as ErrInternalError. A failure to
// update the last-active timestamp is logged but does not fail the request.
func (h *AuthHandler) Status(c *gin.Context) {
	ctx := c.Request.Context()
	_, span := observability.TraceHandlerFunction(ctx, "status")
	defer observability.FinishSpan(span, nil)

	respondUnauthenticated := func() {
		c.JSON(http.StatusOK, gin.H{
			"authenticated": false,
			"user":          nil,
		})
	}

	session := sessions.Default(c)
	rawID := session.Get(middleware.UserIDKey)
	if rawID == nil {
		span.SetAttributes(attribute.Bool("auth.authenticated", false))
		respondUnauthenticated()
		return
	}

	// Assert once, with a check: a malformed session value is treated as an
	// invalid session instead of panicking inside the handler.
	userID, ok := rawID.(int)
	if !ok {
		session.Clear()
		if err := session.Save(); err != nil {
			h.logger.Error(ctx, "Error saving session", err, map[string]interface{}{"error": err.Error()})
		}
		span.SetAttributes(attribute.Bool("auth.authenticated", false))
		respondUnauthenticated()
		return
	}

	span.SetAttributes(
		attribute.Bool("auth.authenticated", true),
		attribute.Int("user.id", userID),
	)

	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error getting user by ID", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}
	if user == nil {
		// User not found, clear session
		session.Clear()
		if err := session.Save(); err != nil {
			h.logger.Error(ctx, "Error saving session", err, map[string]interface{}{"error": err.Error()})
		}
		span.SetAttributes(attribute.Bool("auth.user_found", false))
		respondUnauthenticated()
		return
	}

	// Update span attributes with user info
	span.SetAttributes(
		attribute.Bool("auth.user_found", true),
		attribute.String("user.username", user.Username),
		attribute.Bool("user.email_provided", user.Email.Valid),
		attribute.String("user.language", user.PreferredLanguage.String),
		attribute.String("user.level", user.CurrentLevel.String),
		attribute.Bool("user.ai_enabled", user.AIEnabled.Bool),
		attribute.String("user.ai_provider", user.AIProvider.String),
		attribute.String("user.ai_model", user.AIModel.String),
	)

	// Update last active timestamp; don't fail the request if this fails.
	if err := h.userService.UpdateLastActive(ctx, user.ID); err != nil {
		h.logger.Error(ctx, "Error updating last active", err, map[string]interface{}{"user_id": user.ID})
	}

	// Convert models.User to api.User with proper API key checking
	apiUser := convertUserToAPIWithService(ctx, user, h.userService)
	c.JSON(http.StatusOK, gin.H{
		"authenticated": true,
		"user":          &apiUser,
	})
}
// Check is a lightweight auth-check endpoint intended for reverse proxy auth_request.
// It requires authentication via middleware and returns 204 when authenticated.
// Unauthenticated requests are rejected by the RequireAuth middleware with 401.
//
// The handler itself performs no authentication and writes no body; it only
// sets the response status to 204 No Content.
func (h *AuthHandler) Check(c *gin.Context) {
// If we reached here, authentication succeeded in middleware
// NOTE(review): the middleware wiring is not visible here — confirm the route is registered behind RequireAuth.
c.Status(http.StatusNoContent)
}
// signupUsernamePattern matches usernames consisting solely of ASCII letters,
// digits and underscores. It is compiled once at package initialisation
// rather than on every signup request.
var signupUsernamePattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)

// Signup handles user registration requests.
//
// Flow:
//  1. Reject with ErrForbidden when the configuration disables signups.
//  2. Bind the JSON body and validate it: username, password and email are
//     required; the username must be 3-50 bytes of [a-zA-Z0-9_]; the password
//     must be at least 8 bytes; the email must pass contextutils.IsValidEmail.
//  3. Reject with ErrRecordExists when the username or the lower-cased email
//     is already taken.
//  4. Resolve language (request value, else first configured language, else
//     "italian"), level (request value matched case-insensitively against the
//     levels configured for the language, else the first configured level) and
//     timezone (request value, else "UTC").
//  5. Create the user, then set the password. If setting the password fails,
//     the freshly created user is deleted on a best-effort basis.
//
// On success it responds with 201 and a success message. No session is
// created; the client must log in separately.
func (h *AuthHandler) Signup(c *gin.Context) {
	ctx := c.Request.Context()
	_, span := observability.TraceHandlerFunction(ctx, "signup")
	defer observability.FinishSpan(span, nil)

	// Check if signups are disabled
	if h.config != nil && h.config.IsSignupDisabled() {
		span.SetAttributes(attribute.Bool("auth.signups_disabled", true))
		HandleAppError(c, contextutils.ErrForbidden)
		return
	}
	span.SetAttributes(attribute.Bool("auth.signups_disabled", false))

	var req api.UserCreateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		if errors.Is(err, openapi_types.ErrValidationEmail) {
			HandleAppError(c, contextutils.ErrInvalidInput)
			return
		}
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			err,
		))
		return
	}

	// Set span attributes for request data (presence flags only for secrets/PII)
	span.SetAttributes(
		attribute.String("signup.username", req.Username),
		attribute.Bool("signup.password_provided", req.Password != ""),
		attribute.Bool("signup.email_provided", req.Email != nil && *req.Email != ""),
		attribute.Bool("signup.language_provided", req.PreferredLanguage != nil && *req.PreferredLanguage != ""),
		attribute.Bool("signup.level_provided", req.CurrentLevel != nil && *req.CurrentLevel != ""),
		attribute.Bool("signup.timezone_provided", req.Timezone != nil && *req.Timezone != ""),
	)

	// Validate required fields
	if req.Username == "" || req.Password == "" || req.Email == nil || *req.Email == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Validate username format (3-50 characters, alphanumeric + underscore)
	if len(req.Username) < 3 || len(req.Username) > 50 || !signupUsernamePattern.MatchString(req.Username) {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Validate password (minimum 8 characters)
	if len(req.Password) < 8 {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Validate email format (convert to string)
	if !contextutils.IsValidEmail(string(*req.Email)) {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Normalize email to lowercase
	email := strings.ToLower(string(*req.Email))
	h.logger.Info(ctx, "Attempting signup for user", map[string]interface{}{"username": req.Username, "email": email})

	// Check if username already exists
	existingUser, err := h.userService.GetUserByUsername(ctx, req.Username)
	if err != nil {
		h.logger.Error(ctx, "Error checking username uniqueness", err, map[string]interface{}{"username": req.Username})
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}
	if existingUser != nil {
		span.SetAttributes(attribute.Bool("signup.username_exists", true))
		HandleAppError(c, contextutils.ErrRecordExists)
		return
	}

	// Check if email already exists
	existingUserByEmail, err := h.userService.GetUserByEmail(ctx, email)
	if err != nil {
		h.logger.Error(ctx, "Error checking email uniqueness", err, map[string]interface{}{"email": email})
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}
	if existingUserByEmail != nil {
		span.SetAttributes(attribute.Bool("signup.email_exists", true))
		HandleAppError(c, contextutils.ErrRecordExists)
		return
	}

	// Set default values for optional fields
	language := "italian" // Fallback when no configuration is available
	if h.config != nil {
		// Default to the first configured language
		if languages := h.config.GetLanguages(); len(languages) > 0 {
			language = languages[0]
		}
	}
	if req.PreferredLanguage != nil && *req.PreferredLanguage != "" {
		language = *req.PreferredLanguage
	}

	// Choose canonical default level for the selected language (first level in config)
	level := ""
	levels := []string{}
	if h.config != nil {
		levels = h.config.GetLevelsForLanguage(language)
		if len(levels) > 0 {
			level = levels[0]
		}
	}

	// If client provided a level, require it to be a canonical code for the language.
	if req.CurrentLevel != nil && *req.CurrentLevel != "" {
		provided := *req.CurrentLevel
		matched := false
		for _, l := range levels {
			if strings.EqualFold(l, provided) {
				level = l // store the canonical spelling, not the client's casing
				matched = true
				break
			}
		}
		if !matched {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
	}

	timezone := "UTC" // Default timezone
	if req.Timezone != nil && *req.Timezone != "" {
		timezone = *req.Timezone
	}

	// Update span attributes with final values
	span.SetAttributes(
		attribute.String("signup.language", language),
		attribute.String("signup.level", level),
		attribute.String("signup.timezone", timezone),
	)

	// Create user with email and timezone (no AI settings)
	user, err := h.userService.CreateUserWithEmailAndTimezone(ctx, req.Username, email, timezone, language, level)
	if err != nil {
		h.logger.Error(ctx, "Error creating user", err, map[string]interface{}{"username": req.Username, "email": email})
		HandleAppError(c, contextutils.WrapError(err, "failed to create user account"))
		return
	}

	// Now set the password hash
	if err := h.userService.UpdateUserPassword(ctx, user.ID, req.Password); err != nil {
		h.logger.Error(ctx, "Error setting user password", err, map[string]interface{}{"user_id": user.ID})
		// Try to clean up the user we just created. The password error was
		// logged above; this entry reports the cleanup failure itself.
		if deleteErr := h.userService.DeleteUser(ctx, user.ID); deleteErr != nil {
			h.logger.Error(ctx, "Error cleaning up user after password set failure", deleteErr, map[string]interface{}{"user_id": user.ID, "error": deleteErr.Error()})
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to create user account"))
		return
	}

	// Update span attributes with created user info
	span.SetAttributes(
		attribute.Int("user.id", user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("user.email", email),
	)
	h.logger.Info(ctx, "Successfully created user", map[string]interface{}{"username": req.Username, "user_id": user.ID})

	// Return success response (no session created, no auto-login)
	c.JSON(http.StatusCreated, api.SuccessResponse{
		Success: true,
		Message: stringPtr("Account created successfully. Please log in."),
	})
}
// GoogleLogin initiates the Google OAuth flow.
//
// It generates a random state value, stores it in the session (together with
// the optional redirect_uri query parameter) so GoogleCallback can verify it,
// and responds with 200 and {"auth_url": <url>}. If the session cannot be
// saved, the error is reported through HandleAppError.
func (h *AuthHandler) GoogleLogin(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "google_login")
	defer observability.FinishSpan(span, nil)

	// Generate a state parameter for security
	state := generateRandomState()

	// Get the redirect URI from query parameters
	redirectURI := c.Query("redirect_uri")

	// Set span attributes
	span.SetAttributes(
		attribute.String("oauth.provider", "google"),
		attribute.String("oauth.state", state),
		attribute.String("oauth.redirect_uri", redirectURI),
	)

	// Store state and redirect URI in session for verification
	session := sessions.Default(c)
	session.Set("oauth_state", state)
	if redirectURI != "" {
		session.Set("oauth_redirect_uri", redirectURI)
	} else {
		// Drop any redirect URI left over from an earlier, abandoned flow so
		// the callback for this flow does not return a stale redirect target.
		session.Delete("oauth_redirect_uri")
	}
	if err := session.Save(); err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to save session"))
		return
	}

	// Generate Google OAuth URL
	authURL := h.oauthService.GetGoogleAuthURL(c.Request.Context(), state)
	c.JSON(http.StatusOK, gin.H{
		"auth_url": authURL,
	})
}
// GoogleCallback handles the OAuth callback from Google.
//
// Steps:
//  1. Read "code" and "state" from the query string; a missing code yields
//     ErrMissingRequired.
//  2. Compare "state" with the value stored in the session by GoogleLogin. A
//     missing, malformed or different stored value yields ErrOAuthStateMismatch.
//  3. If the session already carries a user ID, treat the request as a
//     duplicate callback and return that user without contacting Google.
//  4. Otherwise remove the stored state/redirect URI from the session,
//     authenticate through the OAuth service, store the user in the session
//     and return a LoginResponse (including the stored redirect URI, if any).
//
// The authorization code is a credential and is never written to the logs.
func (h *AuthHandler) GoogleCallback(c *gin.Context) {
	ctx := c.Request.Context()
	_, span := observability.TraceHandlerFunction(ctx, "google_callback")
	defer observability.FinishSpan(span, nil)

	// Get the authorization code and state from query parameters
	code := c.Query("code")
	state := c.Query("state")

	// Set span attributes
	span.SetAttributes(
		attribute.String("oauth.provider", "google"),
		attribute.Bool("oauth.code_provided", code != ""),
		attribute.String("oauth.state", state),
	)

	// Log only whether a code was supplied; the code itself is a secret.
	h.logger.Info(ctx, "Google OAuth callback received", map[string]interface{}{"code_provided": code != "", "state": state})
	if code == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Verify state parameter for OAuth security (CSRF protection)
	session := sessions.Default(c)
	storedState := session.Get("oauth_state")
	h.logger.Info(ctx, "OAuth state verification", map[string]interface{}{"stored_state": storedState, "received_state": state})

	// Enforce strict state verification for security
	if storedState == nil {
		h.logger.Error(ctx, "No OAuth state found in session - possible CSRF attack or session issue", nil, map[string]interface{}{"state": state})
		span.SetAttributes(attribute.Bool("oauth.state_valid", false))
		HandleAppError(c, contextutils.ErrOAuthStateMismatch)
		return
	}
	// A stored value of the wrong type is handled as a mismatch, not a panic.
	if expected, ok := storedState.(string); !ok || expected != state {
		h.logger.Error(ctx, "OAuth state mismatch - possible CSRF attack", nil, map[string]interface{}{"stored_state": storedState, "received_state": state})
		span.SetAttributes(attribute.Bool("oauth.state_valid", false))
		HandleAppError(c, contextutils.ErrOAuthStateMismatch)
		return
	}
	span.SetAttributes(attribute.Bool("oauth.state_valid", true))
	h.logger.Info(ctx, "OAuth state verification successful")

	// Check if user is already authenticated (prevent duplicate callbacks).
	// A session value that is not an int is ignored and the normal
	// authentication flow below overwrites it.
	if existingUserID, ok := session.Get(middleware.UserIDKey).(int); ok {
		h.logger.Info(ctx, "User already authenticated during OAuth callback", map[string]interface{}{
			"user_id": existingUserID,
		})
		span.SetAttributes(attribute.Bool("oauth.duplicate_callback", true))

		// Get user information for the response
		user, err := h.userService.GetUserByID(ctx, existingUserID)
		if err != nil {
			h.logger.Error(ctx, "Error getting user by ID", err, map[string]interface{}{"user_id": existingUserID})
			HandleAppError(c, contextutils.ErrInternalError)
			return
		}
		if user == nil {
			h.logger.Error(ctx, "User not found", nil, map[string]interface{}{"user_id": existingUserID})
			HandleAppError(c, contextutils.ErrInternalError)
			return
		}

		// Convert models.User to api.User with proper API key checking
		apiUser := convertUserToAPIWithService(ctx, user, h.userService)

		// Return success response for already authenticated user
		c.JSON(http.StatusOK, api.LoginResponse{
			Success: boolPtr(true),
			Message: stringPtr("Already authenticated"),
			User:    &apiUser,
		})
		return
	}

	// Get the stored redirect URI from session (ignored if absent or malformed)
	redirectURI, _ := session.Get("oauth_redirect_uri").(string)

	// Clear the state and redirect URI from session so they cannot be replayed
	session.Delete("oauth_state")
	session.Delete("oauth_redirect_uri")
	if err := session.Save(); err != nil {
		h.logger.Error(ctx, "Failed to save session", err, map[string]interface{}{"error": err.Error()})
		HandleAppError(c, contextutils.WrapError(err, "failed to save session"))
		return
	}

	// Authenticate user with Google OAuth
	user, err := h.oauthService.AuthenticateGoogleUser(ctx, code, h.userService)
	if err != nil {
		h.logger.Error(ctx, "Google OAuth authentication failed", err, map[string]interface{}{"error": err.Error()})

		// Check if this is a signup disabled error (structured)
		if errors.Is(err, services.ErrSignupsDisabled) {
			span.SetAttributes(attribute.Bool("oauth.signups_disabled", true))
			HandleAppError(c, contextutils.ErrForbidden)
			return
		}

		// Provide better error messages to the frontend using structured error checking
		errorMessage := "Authentication failed"
		switch {
		case errors.Is(err, services.ErrOAuthCodeAlreadyUsed):
			errorMessage = "This authentication link has already been used. Please try signing in again."
		case errors.Is(err, services.ErrOAuthClientConfig):
			errorMessage = "OAuth configuration error. Please contact support."
		case errors.Is(err, services.ErrOAuthInvalidRequest):
			errorMessage = "Invalid authentication request. Please try again."
		case errors.Is(err, services.ErrOAuthUnauthorized):
			errorMessage = "OAuth client is not authorized. Please contact support."
		case errors.Is(err, services.ErrOAuthUnsupportedGrant):
			errorMessage = "Unsupported OAuth grant type. Please contact support."
		}
		HandleAppError(c, contextutils.WrapError(err, errorMessage))
		return
	}

	// Update span attributes with user info
	span.SetAttributes(
		attribute.Int("user.id", user.ID),
		attribute.String("user.username", user.Username),
		attribute.Bool("user.email_provided", user.Email.Valid),
		attribute.String("user.language", user.PreferredLanguage.String),
		attribute.String("user.level", user.CurrentLevel.String),
		attribute.Bool("user.is_new", user.CreatedAt.After(time.Now().Add(-5*time.Minute))), // Rough check if user was just created
	)

	// Update last active (best effort: a failure does not fail the login)
	if err := h.userService.UpdateLastActive(ctx, user.ID); err != nil {
		h.logger.Warn(ctx, "Failed to update last active for user", map[string]interface{}{"user_id": user.ID, "error": err.Error()})
	}

	// Create session
	session.Set(middleware.UserIDKey, user.ID)
	session.Set(middleware.UsernameKey, user.Username)
	h.logger.Info(ctx, "Setting session for user", map[string]interface{}{"user_id": user.ID, "username": user.Username})
	if err := session.Save(); err != nil {
		h.logger.Error(ctx, "Failed to save session", err, map[string]interface{}{"error": err.Error()})
		HandleAppError(c, contextutils.WrapError(err, "failed to create session"))
		return
	}

	// Convert models.User to api.User with proper API key checking
	apiUser := convertUserToAPIWithService(ctx, user, h.userService)
	h.logger.Info(ctx, "Google OAuth successful for user", map[string]interface{}{"username": user.Username, "user_id": user.ID})

	// Return user info with redirect URI if available
	response := api.LoginResponse{
		Success: boolPtr(true),
		Message: stringPtr("Google authentication successful"),
		User:    &apiUser,
	}
	// Add redirect URI to response if it was stored
	if redirectURI != "" {
		response.RedirectUri = &redirectURI
	}
	c.JSON(http.StatusOK, response)
}
// generateRandomState generates a cryptographically secure random state
// parameter for OAuth security.
//
// The result is a 32-character string drawn uniformly from [a-zA-Z0-9].
// Random bytes come from crypto/rand and are mapped to the alphabet with
// rejection sampling: bytes at or above the largest multiple of the alphabet
// size are discarded, because a plain modulo would make the first characters
// of the alphabet slightly more likely (256 is not a multiple of 62).
//
// It panics if the system's secure random source fails; there is deliberately
// no fallback to a weaker generator.
func generateRandomState() string {
	const (
		charset     = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		stateLength = 32
		// Bytes >= maxUnbiased would introduce modulo bias and are rejected.
		maxUnbiased = 256 - 256%len(charset)
	)

	state := make([]byte, 0, stateLength)
	// Read in batches; a few bytes are rejected, so over-read to make a
	// second pass unlikely.
	buf := make([]byte, 2*stateLength)
	for len(state) < stateLength {
		if _, err := rand.Read(buf); err != nil {
			// If crypto/rand fails, we have a serious system issue - don't fallback to weaker randomness
			panic("Cryptographic random number generation failed: " + err.Error())
		}
		for _, b := range buf {
			if int(b) >= maxUnbiased {
				continue
			}
			state = append(state, charset[int(b)%len(charset)])
			if len(state) == stateLength {
				break
			}
		}
	}
	return string(state)
}
// SignupStatus reports whether signups are disabled and which OAuth
// whitelist entries (domains / emails) are configured.
//
// With no configuration available, signups are reported as enabled and the
// whitelist as empty. The whitelist is considered enabled when at least one
// allowed domain or allowed email is configured.
func (h *AuthHandler) SignupStatus(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "signup_status")
	defer observability.FinishSpan(span, nil)

	var (
		disabled        bool
		domains, emails []string
	)

	cfg := h.config
	hasConfig := cfg != nil
	if hasConfig {
		disabled = cfg.IsSignupDisabled()
		if system := cfg.System; system != nil {
			domains = system.Auth.AllowedDomains
			emails = system.Auth.AllowedEmails
		}
	}
	// Both slices stay nil unless a system config supplied them.
	whitelistEnabled := len(domains) > 0 || len(emails) > 0

	span.SetAttributes(
		attribute.Bool("auth.signups_disabled", disabled),
		attribute.Bool("auth.config_available", hasConfig),
		attribute.Bool("auth.oauth_whitelist_enabled", whitelistEnabled),
	)

	payload := gin.H{
		"signups_disabled":        disabled,
		"oauth_whitelist_enabled": whitelistEnabled,
		"allowed_domains":         domains,
		"allowed_emails":          emails,
	}
	c.JSON(http.StatusOK, payload)
}
package handlers
import (
"context"
"errors"
"quizapp/internal/middleware"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// Sentinel errors returned by the authentication/authorization helpers in
// this file. Compare against them with errors.Is.
var (
// ErrUnauthenticated indicates no current user could be determined.
// GetCurrentUserID returns it when neither the Gin context nor the session
// holds a user ID; RequireSelfOrAdmin returns it when the current ID is 0.
ErrUnauthenticated = errors.New("user not authenticated")
// ErrInvalidUserID indicates the stored user identifier is malformed
// (present, but not an int).
ErrInvalidUserID = errors.New("invalid user id")
// ErrForbidden indicates the user lacks permissions for the operation.
// RequireSelfOrAdmin returns it when the caller is neither the target user nor an admin.
ErrForbidden = errors.New("forbidden")
)
// GetCurrentUserID resolves the ID of the authenticated user for this request.
//
// Lookup order:
//  1. the Gin context value stored under middleware.UserIDKey;
//  2. the session value stored under the same key, used only when the
//     context has no such key at all.
//
// It returns ErrUnauthenticated when neither source holds a value, and
// ErrInvalidUserID when a value is present but is not an int.
func GetCurrentUserID(c *gin.Context) (int, error) {
	raw, inContext := c.Get(middleware.UserIDKey)
	if !inContext {
		// Context not populated: consult the session store instead.
		raw = sessions.Default(c).Get(middleware.UserIDKey)
		if raw == nil {
			return 0, ErrUnauthenticated
		}
	}

	if id, isInt := raw.(int); isInt {
		return id, nil
	}
	return 0, ErrInvalidUserID
}
// authzAdminChecker is the minimal capability needed from user service for admin checks.
// Any concrete user service that implements IsAdmin satisfies this interface.
// RequireSelfOrAdmin accepts it so callers can pass any implementation.
type authzAdminChecker interface {
// IsAdmin reports whether the user identified by userID has admin
// privileges, or an error if that could not be determined.
IsAdmin(ctx context.Context, userID int) (bool, error)
}
// RequireSelfOrAdmin authorizes an action on targetID performed by currentID.
//
// It returns:
//   - ErrUnauthenticated when currentID is 0;
//   - nil when currentID equals targetID (no admin lookup is performed);
//   - the error from svc.IsAdmin when the admin lookup fails;
//   - ErrForbidden when the caller is not an admin;
//   - nil when the caller is an admin.
func RequireSelfOrAdmin(ctx context.Context, svc authzAdminChecker, currentID, targetID int) error {
	switch currentID {
	case 0:
		return ErrUnauthenticated
	case targetID:
		return nil
	}

	admin, err := svc.IsAdmin(ctx, currentID)
	switch {
	case err != nil:
		return err
	case admin:
		return nil
	default:
		return ErrForbidden
	}
}
package handlers
import (
"context"
"encoding/json"
"time"
"quizapp/internal/api"
"quizapp/internal/models"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
openapi_types "github.com/oapi-codegen/runtime/types"
)
// Helper functions for pointer conversion.

// stringPtr returns a pointer to a newly allocated copy of s.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// boolPtr returns a pointer to a newly allocated copy of b.
func boolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}
// int64Ptr converts i to int64 and returns a pointer to the converted value.
func int64Ptr(i int) *int64 {
	p := new(int64)
	*p = int64(i)
	return p
}
// float32Ptr returns a pointer to a newly allocated copy of f.
func float32Ptr(f float32) *float32 {
	p := new(float32)
	*p = f
	return p
}
// intPtr returns a pointer to a newly allocated copy of i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
// int64FromUint converts u to int64 and returns a pointer to the converted
// value. The conversion is a plain Go integer conversion with no range check.
func int64FromUint(u uint) *int64 {
	p := new(int64)
	*p = int64(u)
	return p
}
// timePtr returns a pointer to a newly allocated copy of t.
func timePtr(t time.Time) *time.Time {
	p := new(time.Time)
	*p = t
	return p
}
// formatTimePtr converts t to UTC, renders it in RFC3339 layout and returns a
// pointer to the resulting string.
func formatTimePtr(t time.Time) *string {
	formatted := t.UTC().Format(time.RFC3339)
	return &formatted
}
// formatTimePointer renders *tp as an RFC3339 string in UTC and returns a
// pointer to it. A nil input yields a nil result.
func formatTimePointer(tp *time.Time) *string {
	var out *string
	if tp != nil {
		formatted := tp.UTC().Format(time.RFC3339)
		out = &formatted
	}
	return out
}
// formatTime converts t to UTC and renders it in RFC3339 layout.
func formatTime(t time.Time) string {
	utc := t.UTC()
	return utc.Format(time.RFC3339)
}
// convertAuthAPIKeyToAPI maps a models.AuthAPIKey onto an api.APIKeySummary.
//
// Only populated source fields are copied: zero IDs, empty strings, zero
// timestamps and an invalid LastUsedAt leave the corresponding summary field
// nil. Timestamps are copied, so the summary does not point into key.
func convertAuthAPIKeyToAPI(key *models.AuthAPIKey) api.APIKeySummary {
	var summary api.APIKeySummary

	// Scalar fields
	if id := key.ID; id != 0 {
		summary.Id = intPtr(id)
	}
	if name := key.KeyName; name != "" {
		summary.KeyName = stringPtr(name)
	}
	if prefix := key.KeyPrefix; prefix != "" {
		summary.KeyPrefix = stringPtr(prefix)
	}
	if key.PermissionLevel != "" {
		level := api.APIKeySummaryPermissionLevel(key.PermissionLevel)
		summary.PermissionLevel = &level
	}

	// Timestamps
	if created := key.CreatedAt; !created.IsZero() {
		summary.CreatedAt = &created
	}
	if updated := key.UpdatedAt; !updated.IsZero() {
		summary.UpdatedAt = &updated
	}
	// LastUsedAt stays nil (null) when the key has never been used.
	if key.LastUsedAt.Valid {
		lastUsed := key.LastUsedAt.Time
		summary.LastUsedAt = &lastUsed
	}

	return summary
}
// convertAuthAPIKeysToAPI converts every key in keys with
// convertAuthAPIKeyToAPI, preserving order. The result is always a non-nil
// slice, so an empty or nil input yields an empty (not nil) slice.
func convertAuthAPIKeysToAPI(keys []models.AuthAPIKey) []api.APIKeySummary {
	summaries := make([]api.APIKeySummary, len(keys))
	for idx := range keys {
		summaries[idx] = convertAuthAPIKeyToAPI(&keys[idx])
	}
	return summaries
}
// convertUserToAPI maps a models.User onto an api.User.
//
// Nullable source columns are copied only when valid; otherwise the API field
// stays nil. AiEnabled is always set (false when the column is null or false).
// HasApiKey is always set to false here; use convertUserToAPIWithService when
// the real API-key state is needed. Roles are included only when the user has
// at least one.
func convertUserToAPI(user *models.User) api.User {
	// ai_enabled is never null in the response.
	aiOn := user.AIEnabled.Valid && user.AIEnabled.Bool
	// has_api_key defaults to false for backwards compatibility; the proper
	// check is done by convertUserToAPIWithService.
	noKey := false

	result := api.User{
		Id:        int64Ptr(user.ID),
		Username:  stringPtr(user.Username),
		AiEnabled: &aiOn,
		HasApiKey: &noKey,
	}

	// Timestamps
	if !user.CreatedAt.IsZero() {
		result.CreatedAt = formatTimePtr(user.CreatedAt)
	}
	if user.LastActive.Valid {
		result.LastActive = formatTimePtr(user.LastActive.Time)
	}

	// Nullable profile fields
	if user.Email.Valid {
		email := user.Email.String
		result.Email = &email
	}
	if user.Timezone.Valid {
		tz := user.Timezone.String
		result.Timezone = &tz
	}
	if user.PreferredLanguage.Valid {
		lang := user.PreferredLanguage.String
		result.PreferredLanguage = &lang
	}
	if user.CurrentLevel.Valid {
		lvl := user.CurrentLevel.String
		result.CurrentLevel = &lvl
	}

	// Nullable AI / notification settings
	if user.AIProvider.Valid {
		provider := user.AIProvider.String
		result.AiProvider = &provider
	}
	if user.AIModel.Valid {
		model := user.AIModel.String
		result.AiModel = &model
	}
	if user.WordOfDayEmailEnabled.Valid {
		wod := user.WordOfDayEmailEnabled.Bool
		result.WordOfDayEmailEnabled = &wod
	}

	// Include user roles if they exist
	if n := len(user.Roles); n > 0 {
		roles := make([]api.Role, n)
		for idx := range user.Roles {
			src := user.Roles[idx]
			roles[idx] = api.Role{
				Id:          int64(src.ID),
				Name:        src.Name,
				Description: src.Description,
				CreatedAt:   formatTime(src.CreatedAt),
				UpdatedAt:   formatTime(src.UpdatedAt),
			}
		}
		result.Roles = &roles
	}

	return result
}
// convertUserToAPIWithService converts a models.User to api.User and fills in
// HasApiKey from the per-provider API key store.
//
// HasApiKey is true only when the user has a non-empty AI provider, a user
// service was supplied, and that service returns a non-empty key without
// error for the provider. The key itself is never placed in the response.
func convertUserToAPIWithService(ctx context.Context, user *models.User, userService services.UserServiceInterface) api.User {
	result := convertUserToAPI(user)

	keyPresent := false
	provider := user.AIProvider
	if provider.Valid && provider.String != "" && userService != nil {
		// Look the key up in the per-provider store; lookup errors count as "no key".
		if key, err := userService.GetUserAPIKey(ctx, user.ID, provider.String); err == nil {
			keyPresent = key != ""
		}
	}

	result.HasApiKey = &keyPresent
	return result
}
// convertQuestionToAPI maps a models.Question onto an api.Question.
//
// Id, DifficultyScore and CorrectAnswer are always set. Every other field is
// set only when the source value is non-empty / non-zero. The content map is
// converted to api.QuestionContent by picking the string entries "question",
// "hint", "passage" and "sentence" plus the string items of "options";
// entries of any other type are skipped.
func convertQuestionToAPI(question *models.Question) api.Question {
	out := api.Question{
		Id:              int64Ptr(question.ID),
		DifficultyScore: float32Ptr(float32(question.DifficultyScore)),
		CorrectAnswer:   intPtr(question.CorrectAnswer),
		// UsageCount removed; use total_responses instead
	}

	if created := question.CreatedAt; !created.IsZero() {
		stamp := formatTime(created)
		out.CreatedAt = &stamp
	}

	// Enum-like string fields
	if question.Type != "" {
		kind := api.QuestionType(question.Type)
		out.Type = &kind
	}
	if question.Language != "" {
		language := api.Language(question.Language)
		out.Language = &language
	}
	if question.Level != "" {
		lvl := api.Level(question.Level)
		out.Level = &lvl
	}
	if question.Status != "" {
		st := api.QuestionStatus(question.Status)
		out.Status = &st
	}
	if question.Explanation != "" {
		out.Explanation = &question.Explanation
	}

	// Convert content map to api.QuestionContent
	if raw := question.Content; raw != nil {
		converted := &api.QuestionContent{}
		if text, ok := raw["question"].(string); ok {
			converted.Question = text
		}
		if text, ok := raw["hint"].(string); ok {
			converted.Hint = &text
		}
		if text, ok := raw["passage"].(string); ok {
			converted.Passage = &text
		}
		if text, ok := raw["sentence"].(string); ok {
			converted.Sentence = &text
		}
		if items, ok := raw["options"].([]interface{}); ok {
			var choices []string
			for _, item := range items {
				// Non-string options are dropped.
				if text, isString := item.(string); isString {
					choices = append(choices, text)
				}
			}
			if len(choices) > 0 {
				converted.Options = choices
			}
		}
		out.Content = converted
	}

	// Variety elements
	if question.TopicCategory != "" {
		out.TopicCategory = &question.TopicCategory
	}
	if question.GrammarFocus != "" {
		out.GrammarFocus = &question.GrammarFocus
	}
	if question.VocabularyDomain != "" {
		out.VocabularyDomain = &question.VocabularyDomain
	}
	if question.Scenario != "" {
		out.Scenario = &question.Scenario
	}
	if question.StyleModifier != "" {
		out.StyleModifier = &question.StyleModifier
	}
	if question.DifficultyModifier != "" {
		out.DifficultyModifier = &question.DifficultyModifier
	}
	if question.TimeContext != "" {
		out.TimeContext = &question.TimeContext
	}

	return out
}
// convertQuestionWithStatsToAPIMap renders a services.QuestionWithStats as a
// generic JSON-compatible map.
//
// The question is first converted to the generated api.Question (with the
// stats counters attached), then round-tripped through encoding/json to get a
// map, so fields the generated type lacks can be added. Currently that is
// "report_reasons", added when non-empty. A nil q produces the map form of an
// empty api.Question. If JSON encoding fails, the map contains only the
// extra fields.
func convertQuestionWithStatsToAPIMap(q *services.QuestionWithStats) map[string]interface{} {
	var apiQ api.Question

	if q != nil {
		if q.Question != nil {
			apiQ = convertQuestionToAPI(q.Question)
		}

		// Attach stats
		apiQ.CorrectCount = intPtr(q.CorrectCount)
		apiQ.IncorrectCount = intPtr(q.IncorrectCount)
		apiQ.TotalResponses = intPtr(q.TotalResponses)
		apiQ.UserCount = intPtr(q.UserCount)
		if q.Reporters != "" {
			apiQ.Reporters = &q.Reporters
		}
		// ConfidenceLevel is optional
		if q.ConfidenceLevel != nil {
			apiQ.ConfidenceLevel = q.ConfidenceLevel
		}
	}

	// Marshal to generic map so we can add fields not present in api.Question
	result := make(map[string]interface{})
	encoded, err := json.Marshal(apiQ)
	if err == nil {
		_ = json.Unmarshal(encoded, &result)
	}

	// Add report_reasons if available on the service struct
	if q != nil && q.ReportReasons != "" {
		result["report_reasons"] = q.ReportReasons
	}
	return result
}
// convertUserProgressToAPI maps a models.UserProgress onto an api.UserProgress.
//
// Parameters:
//   - ctx, userID, userLookup: passed to contextutils.FormatTimeInUserTimezone
//     to render timestamps; when that fails or returns an empty string, the
//     timestamp is rendered in UTC instead.
//   - progress: the source progress record.
//
// AccuracyRate is divided by 100 before being returned. Performance metrics
// with a nil entry are skipped, and zero timestamps are left unset.
func convertUserProgressToAPI(ctx context.Context, progress *models.UserProgress, userID int, userLookup func(context.Context, int) (*models.User, error)) api.UserProgress {
	// formatTS renders t as RFC3339 in the user's timezone, falling back to UTC.
	formatTS := func(t time.Time) string {
		s, _, err := contextutils.FormatTimeInUserTimezone(ctx, userID, t, time.RFC3339, userLookup)
		if err != nil || s == "" {
			return t.In(time.UTC).Format(time.RFC3339)
		}
		return s
	}

	apiProgress := api.UserProgress{
		TotalQuestions: intPtr(progress.TotalQuestions),
		CorrectAnswers: intPtr(progress.CorrectAnswers),
		AccuracyRate:   float32Ptr(float32(progress.AccuracyRate / 100.0)),
	}
	if progress.CurrentLevel != "" {
		level := api.Level(progress.CurrentLevel)
		apiProgress.CurrentLevel = &level
	}
	if progress.SuggestedLevel != "" {
		level := api.Level(progress.SuggestedLevel)
		apiProgress.SuggestedLevel = &level
	}
	if progress.WeakAreas != nil {
		apiProgress.WeakAreas = &progress.WeakAreas
	}

	// Convert performance metrics
	if progress.PerformanceByTopic != nil {
		perfMap := make(map[string]api.PerformanceMetrics)
		for topic, metrics := range progress.PerformanceByTopic {
			if metrics == nil {
				continue
			}
			entry := api.PerformanceMetrics{
				TotalAttempts:         intPtr(metrics.TotalAttempts),
				CorrectAttempts:       intPtr(metrics.CorrectAttempts),
				AverageResponseTimeMs: float32Ptr(float32(metrics.AverageResponseTimeMs)),
			}
			if !metrics.LastUpdated.IsZero() {
				updated := formatTS(metrics.LastUpdated)
				entry.LastUpdated = &updated
			}
			perfMap[topic] = entry
		}
		apiProgress.PerformanceByTopic = &perfMap
	}

	// Convert recent activity
	if progress.RecentActivity != nil {
		var recentActivity []api.UserResponse
		for _, activity := range progress.RecentActivity {
			// Copy the flag: taking the address of the range variable's field
			// makes every entry share one value on Go versions before 1.22.
			isCorrect := activity.IsCorrect
			apiActivity := api.UserResponse{
				QuestionId: int64Ptr(activity.QuestionID),
				IsCorrect:  &isCorrect,
			}
			if !activity.CreatedAt.IsZero() {
				created := formatTS(activity.CreatedAt)
				apiActivity.CreatedAt = &created
			}
			recentActivity = append(recentActivity, apiActivity)
		}
		apiProgress.RecentActivity = &recentActivity
	}

	return apiProgress
}
// convertDailyAssignmentToAPI maps a daily question assignment (with its
// question) onto an api.DailyQuestionWithDetails.
//
// Timestamps (CompletedAt, CreatedAt, SubmittedAt) are rendered as RFC3339 via
// contextutils.FormatTimeInUserTimezone using ctx, userID and userLookup; when
// that fails or returns an empty string, they are rendered in UTC instead.
// AssignmentDate is wrapped in openapi_types.Date. When the assignment has a
// question and a positive DailyShownCount, the question's TotalResponses is
// overridden with that count. Per-user counters are attached when they are
// non-negative.
func convertDailyAssignmentToAPI(ctx context.Context, assignment *models.DailyQuestionAssignmentWithQuestion, userID int, userLookup func(context.Context, int) (*models.User, error)) api.DailyQuestionWithDetails {
	// render formats t for the user, with an error-checked UTC fallback.
	render := func(t time.Time) string {
		if s, _, err := contextutils.FormatTimeInUserTimezone(ctx, userID, t, time.RFC3339, userLookup); err == nil && s != "" {
			return s
		}
		return t.In(time.UTC).Format(time.RFC3339)
	}

	var completed *string
	if assignment.CompletedAt.Valid {
		text := render(assignment.CompletedAt.Time)
		completed = &text
	}

	var question api.Question
	if assignment.Question != nil {
		question = convertQuestionToAPI(assignment.Question)
		// Override total_responses so UI 'Shown' reflects Daily-only impressions
		if assignment.DailyShownCount > 0 {
			question.TotalResponses = &assignment.DailyShownCount
		}
	}

	// CreatedAt in user's timezone
	created := render(assignment.CreatedAt)

	var submitted *string
	if assignment.SubmittedAt != nil {
		text := render(*assignment.SubmittedAt)
		submitted = &text
	}

	details := api.DailyQuestionWithDetails{
		Id:         int64(assignment.ID),
		UserId:     int64(assignment.UserID),
		QuestionId: int64(assignment.QuestionID),
		// AssignmentDate: date-only value (YYYY-MM-DD) using openapi_types.Date
		AssignmentDate:  openapi_types.Date{Time: assignment.AssignmentDate},
		IsCompleted:     assignment.IsCompleted,
		CompletedAt:     completed,
		CreatedAt:       created,
		UserAnswerIndex: assignment.UserAnswerIndex,
		SubmittedAt:     submitted,
		Question:        question,
	}

	// Attach per-user stats when available
	if assignment.DailyShownCount >= 0 {
		shownCount := int64(assignment.DailyShownCount)
		details.UserShownCount = &shownCount
	}
	if assignment.UserTotalResponses >= 0 {
		totalCount := int64(assignment.UserTotalResponses)
		details.UserTotalResponses = &totalCount
	}
	if assignment.UserCorrectCount >= 0 {
		correctCount := int64(assignment.UserCorrectCount)
		details.UserCorrectCount = &correctCount
	}
	if assignment.UserIncorrectCount >= 0 {
		incorrectCount := int64(assignment.UserIncorrectCount)
		details.UserIncorrectCount = &incorrectCount
	}

	return details
}
// convertDailyAssignmentsToAPI converts each assignment with
// convertDailyAssignmentToAPI, preserving input order. The returned slice is
// never nil, so an empty input yields an empty (not nil) slice.
func convertDailyAssignmentsToAPI(ctx context.Context, assignments []*models.DailyQuestionAssignmentWithQuestion, userID int, userLookup func(context.Context, int) (*models.User, error)) []api.DailyQuestionWithDetails {
	converted := make([]api.DailyQuestionWithDetails, 0, len(assignments))
	for _, assignment := range assignments {
		converted = append(converted, convertDailyAssignmentToAPI(ctx, assignment, userID, userLookup))
	}
	return converted
}
// convertDailyProgressToAPI copies the date (as a date-only value) and the
// completed/total counters from models.DailyProgress into api.DailyProgress.
func convertDailyProgressToAPI(progress *models.DailyProgress) api.DailyProgress {
	var out api.DailyProgress
	out.Date = openapi_types.Date{Time: progress.Date}
	out.Completed = progress.Completed
	out.Total = progress.Total
	return out
}
// convertStoryToAPI builds the api.Story representation of a models.Story.
//
// Optional pointer fields are carried over unchanged (nil stays nil), zero
// CreatedAt/UpdatedAt values are left unset, and the section-length override
// is converted to its API enum type only when present.
func convertStoryToAPI(story *models.Story) api.Story {
	status := (*api.StoryStatus)(stringPtr(string(story.Status)))

	out := api.Story{
		Id:       int64FromUint(story.ID),
		UserId:   int64FromUint(story.UserID),
		Title:    stringPtr(story.Title),
		Language: stringPtr(story.Language),
		Status:   status,

		// Optional descriptive fields: copying a nil pointer leaves the
		// destination nil, which matches "only set when present".
		Subject:            story.Subject,
		AuthorStyle:        story.AuthorStyle,
		TimePeriod:         story.TimePeriod,
		Genre:              story.Genre,
		Tone:               story.Tone,
		CharacterNames:     story.CharacterNames,
		CustomInstructions: story.CustomInstructions,
	}

	// Enum field: set only when the model carries a value.
	if override := story.SectionLengthOverride; override != nil {
		converted := api.StorySectionLengthOverride(*override)
		out.SectionLengthOverride = &converted
	}

	if created := story.CreatedAt; !created.IsZero() {
		out.CreatedAt = timePtr(created)
	}
	if updated := story.UpdatedAt; !updated.IsZero() {
		out.UpdatedAt = timePtr(updated)
	}
	if last := story.LastSectionGeneratedAt; last != nil {
		out.LastSectionGeneratedAt = timePtr(*last)
	}
	return out
}
// convertStorySectionToAPI builds the api.StorySection representation of a
// models.StorySection. GeneratedAt and GenerationDate are only populated when
// the model values are non-zero; GenerationDate is exposed as a date-only
// openapi_types.Date.
func convertStorySectionToAPI(section *models.StorySection) api.StorySection {
	out := api.StorySection{
		Id:            int64FromUint(section.ID),
		StoryId:       int64FromUint(section.StoryID),
		SectionNumber: intPtr(section.SectionNumber),
		Content:       stringPtr(section.Content),
		LanguageLevel: stringPtr(section.LanguageLevel),
		WordCount:     intPtr(section.WordCount),
	}
	if generatedAt := section.GeneratedAt; !generatedAt.IsZero() {
		out.GeneratedAt = timePtr(generatedAt)
	}
	if day := section.GenerationDate; !day.IsZero() {
		out.GenerationDate = &openapi_types.Date{Time: day}
	}
	return out
}
// Convert models.StoryWithSections to api.StoryWithSections
func convertStoryWithSectionsToAPI(story *models.StoryWithSections) api.StoryWithSections {
apiStory := api.StoryWithSections{
Id: int64FromUint(story.ID),
UserId: int64FromUint(story.UserID),
Title: stringPtr(story.Title),
Language: stringPtr(story.Language),
Status: (*api.StoryWithSectionsStatus)(stringPtr(string(story.Status))),
AutoGenerationPaused: boolPtr(story.AutoGenerationPaused),
}
if story.Subject != nil {
apiStory.Subject = story.Subject
}
if story.AuthorStyle != nil {
apiStory.AuthorStyle = story.AuthorStyle
}
if story.TimePeriod != nil {
apiStory.TimePeriod = story.TimePeriod
}
if story.Genre != nil {
apiStory.Genre = story.Genre
}
if story.Tone != nil {
apiStory.Tone = story.Tone
}
if story.CharacterNames != nil {
apiStory.CharacterNames = story.CharacterNames
}
if story.CustomInstructions != nil {
apiStory.CustomInstructions = story.CustomInstructions
}
// Handle enum field - only set if not nil (will be omitted from JSON due to omitempty)
if story.SectionLengthOverride != nil {
lengthOverride := api.StoryWithSectionsSectionLengthOverride(*story.SectionLengthOverride)
apiStory.SectionLengthOverride = &lengthOverride
}
if !story.CreatedAt.IsZero() {
apiStory.CreatedAt = timePtr(story.CreatedAt)
}
if !story.UpdatedAt.IsZero() {
apiStory.UpdatedAt = timePtr(story.UpdatedAt)
}
if story.LastSectionGeneratedAt != nil {
apiStory.LastSectionGeneratedAt = timePtr(*story.LastSectionGeneratedAt)
}
// Convert sections using the section conversion function
if len(story.Sections) > 0 {
apiSections := make([]api.StorySection, len(story.Sections))
for i, section := range story.Sections {
apiSections[i] = convertStorySectionToAPI(§ion)
}
apiStory.Sections = &apiSections
}
return apiStory
}
// convertStorySectionWithQuestionsToAPI builds the
// api.StorySectionWithQuestions representation of a
// models.StorySectionWithQuestions.
//
// GeneratedAt and GenerationDate are only populated when non-zero, and
// Questions is only populated when the section has at least one question.
func convertStorySectionWithQuestionsToAPI(sectionWithQuestions *models.StorySectionWithQuestions) api.StorySectionWithQuestions {
	apiSectionWithQuestions := api.StorySectionWithQuestions{
		Id:            int64FromUint(sectionWithQuestions.ID),
		StoryId:       int64FromUint(sectionWithQuestions.StoryID),
		SectionNumber: intPtr(sectionWithQuestions.SectionNumber),
		Content:       stringPtr(sectionWithQuestions.Content),
		LanguageLevel: stringPtr(sectionWithQuestions.LanguageLevel),
		WordCount:     intPtr(sectionWithQuestions.WordCount),
	}
	if !sectionWithQuestions.GeneratedAt.IsZero() {
		apiSectionWithQuestions.GeneratedAt = timePtr(sectionWithQuestions.GeneratedAt)
	}
	// Convert time.Time to openapi_types.Date for generation_date
	if !sectionWithQuestions.GenerationDate.IsZero() {
		generationDate := openapi_types.Date{Time: sectionWithQuestions.GenerationDate}
		apiSectionWithQuestions.GenerationDate = &generationDate
	}
	// Convert questions
	if len(sectionWithQuestions.Questions) > 0 {
		apiQuestions := make([]api.StorySectionQuestion, len(sectionWithQuestions.Questions))
		for i, question := range sectionWithQuestions.Questions {
			// Point Options at a variable scoped to this iteration. Taking the
			// address of the range variable's field would, on toolchains
			// before Go 1.22, make every element share the same storage and
			// therefore report the last question's options.
			options := question.Options
			apiQuestions[i] = api.StorySectionQuestion{
				Id:                 int64FromUint(question.ID),
				SectionId:          int64FromUint(question.SectionID),
				QuestionText:       stringPtr(question.QuestionText),
				Options:            &options,
				CorrectAnswerIndex: intPtr(question.CorrectAnswerIndex),
				CreatedAt:          timePtr(question.CreatedAt),
			}
			if question.Explanation != nil {
				apiQuestions[i].Explanation = question.Explanation
			}
		}
		apiSectionWithQuestions.Questions = &apiQuestions
	}
	return apiSectionWithQuestions
}
package handlers
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// DailyQuestionHandler handles daily question-related HTTP requests.
// All fields are populated by NewDailyQuestionHandler.
type DailyQuestionHandler struct {
// userService is used for user lookups (GetUserByID), e.g. to resolve the
// user's timezone and language/level preferences.
userService services.UserServiceInterface
// dailyQuestionService performs the daily-question operations the handlers
// delegate to (fetch, regenerate, complete/reset, progress, history, submit).
dailyQuestionService services.DailyQuestionServiceInterface
// cfg is the application configuration stored by the constructor.
cfg *config.Config
// logger is used for structured Info/Warn/Error logging in the handlers.
logger *observability.Logger
}
// NewDailyQuestionHandler returns a DailyQuestionHandler wired with the given
// user service, daily question service, configuration and logger.
func NewDailyQuestionHandler(
	userService services.UserServiceInterface,
	dailyQuestionService services.DailyQuestionServiceInterface,
	cfg *config.Config,
	logger *observability.Logger,
) *DailyQuestionHandler {
	handler := new(DailyQuestionHandler)
	handler.userService = userService
	handler.dailyQuestionService = dailyQuestionService
	handler.cfg = cfg
	handler.logger = logger
	return handler
}
// ParseDateInUserTimezone parses dateStr for the given user by delegating to
// contextutils.ParseDateInUserTimezone, supplying the handler's user service
// as the user lookup. The three results are passed through unchanged.
func (h *DailyQuestionHandler) ParseDateInUserTimezone(ctx context.Context, userID int, dateStr string) (time.Time, string, error) {
	lookupUser := h.userService.GetUserByID
	return contextutils.ParseDateInUserTimezone(ctx, userID, dateStr, lookupUser)
}
// GetDailyQuestions handles GET /v1/daily/questions/{date}.
//
// It resolves the session user, parses the date in the user's timezone, and
// loads the user's assignments for that date. When there are no assignments,
// or any assignment's question does not match the user's current preferred
// language and level, the questions are regenerated and re-read. A
// regeneration failure is logged and the previously loaded assignments are
// returned instead of failing the request.
//
// Responds 200 with {"questions": [...], "date": "<date param>"}; error
// responses are produced through HandleAppError.
func (h *DailyQuestionHandler) GetDailyQuestions(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_daily_questions")
	defer observability.FinishSpan(span, nil)
	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	// Parse date parameter
	dateStr := c.Param("date")
	if dateStr == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	// Parse date in user's timezone
	date, timezone, err := h.ParseDateInUserTimezone(ctx, userID, dateStr)
	if err != nil {
		// Check if it's an invalid date format error
		if strings.Contains(err.Error(), "invalid date format") {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}
	// Add span attributes for observability
	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", dateStr),
		attribute.String("timezone", timezone),
	)
	// Get user to check current language preferences
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for language preference check", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}
	// Check if user has valid language and level preferences
	if !user.PreferredLanguage.Valid || !user.CurrentLevel.Valid {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	currentLanguage := user.PreferredLanguage.String
	currentLevel := user.CurrentLevel.String
	// Get daily questions for the date
	questions, err := h.dailyQuestionService.GetDailyQuestions(ctx, userID, date)
	if err != nil {
		h.logger.Error(ctx, "Failed to get daily questions", err, map[string]interface{}{
			"user_id":  userID,
			"date":     dateStr,
			"timezone": timezone,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get daily questions"))
		return
	}
	// Decide whether the stored assignments still match the user's current
	// preferences. Assignments whose Question is nil are skipped: the
	// converters in this package treat Question as optional, so dereferencing
	// it unconditionally here could panic.
	needsRegeneration := false
	regenerationReason := ""
	var oldLanguage, oldLevel string
	if len(questions) == 0 {
		// No questions exist, need to generate them
		needsRegeneration = true
		regenerationReason = "no_existing_questions"
	} else {
		haveOld := false
		for _, assignment := range questions {
			question := assignment.Question
			if question == nil {
				continue
			}
			if !haveOld {
				// Remember the first available question's settings for logging.
				oldLanguage, oldLevel = question.Language, question.Level
				haveOld = true
			}
			if question.Language != currentLanguage || question.Level != currentLevel {
				needsRegeneration = true
				regenerationReason = "preference_mismatch"
				break
			}
		}
	}
	// If questions don't match current preferences, regenerate them
	if needsRegeneration {
		h.logger.Info(ctx, "Regenerating daily questions due to language preference change", map[string]interface{}{
			"user_id":      userID,
			"date":         dateStr,
			"old_language": oldLanguage,
			"old_level":    oldLevel,
			"new_language": currentLanguage,
			"new_level":    currentLevel,
			"reason":       regenerationReason,
		})
		// Regenerate daily questions with current preferences
		err = h.dailyQuestionService.RegenerateDailyQuestions(ctx, userID, date)
		if err != nil {
			// Check if this is a "no questions available" error
			if contextutils.IsError(err, contextutils.ErrNoQuestionsAvailable) {
				h.logger.Warn(ctx, "No questions available in preferred language, keeping existing questions", map[string]interface{}{
					"user_id":  userID,
					"date":     dateStr,
					"language": currentLanguage,
					"level":    currentLevel,
					"error":    err.Error(),
				})
				// Continue with existing questions rather than failing completely
			} else {
				h.logger.Error(ctx, "Failed to regenerate daily questions", err, map[string]interface{}{
					"user_id": userID,
					"date":    dateStr,
				})
				// Continue with existing questions rather than failing completely
				h.logger.Warn(ctx, "Continuing with existing questions due to regeneration failure", map[string]interface{}{
					"user_id": userID,
					"date":    dateStr,
				})
			}
		} else {
			// Get the regenerated questions
			questions, err = h.dailyQuestionService.GetDailyQuestions(ctx, userID, date)
			if err != nil {
				h.logger.Error(ctx, "Failed to get regenerated daily questions", err, map[string]interface{}{
					"user_id": userID,
					"date":    dateStr,
				})
				HandleAppError(c, contextutils.WrapError(err, "failed to get daily questions"))
				return
			}
		}
	}
	// Convert to API types using shared converter
	apiQuestions := convertDailyAssignmentsToAPI(ctx, questions, userID, h.userService.GetUserByID)
	c.JSON(http.StatusOK, gin.H{
		"questions": apiQuestions,
		"date":      dateStr,
	})
}
// MarkQuestionCompleted handles POST /v1/daily/questions/{date}/complete/{questionId}.
//
// It validates the session and path parameters, parses the date in the user's
// timezone, and asks the daily question service to mark the assignment as
// completed. Responds 200 with a SuccessResponse; failures are reported via
// HandleAppError (ErrAssignmentNotFound is passed through as-is).
func (h *DailyQuestionHandler) MarkQuestionCompleted(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "mark_daily_question_completed")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Both path parameters are required.
	rawDate, rawQuestionID := c.Param("date"), c.Param("questionId")
	if rawDate == "" || rawQuestionID == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Interpret the date in the user's timezone.
	day, tzName, parseErr := h.ParseDateInUserTimezone(ctx, userID, rawDate)
	if parseErr != nil {
		// Anything other than a malformed date is reported as a wrapped error.
		if !strings.Contains(parseErr.Error(), "invalid date format") {
			HandleAppError(c, contextutils.WrapError(parseErr, "failed to get user information"))
			return
		}
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	questionID, convErr := strconv.Atoi(rawQuestionID)
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Record request details on the span.
	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", rawDate),
		attribute.Int("question_id", questionID),
		attribute.String("timezone", tzName),
	)

	if svcErr := h.dailyQuestionService.MarkQuestionCompleted(ctx, userID, questionID, day); svcErr != nil {
		h.logger.Error(ctx, "Failed to mark daily question as completed", svcErr, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
			"date":        rawDate,
			"timezone":    tzName,
		})
		// A missing assignment is surfaced with its dedicated error.
		if contextutils.IsError(svcErr, contextutils.ErrAssignmentNotFound) {
			HandleAppError(c, contextutils.ErrAssignmentNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(svcErr, "failed to mark question as completed"))
		}
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{
		Success: true,
		Message: stringPtr("Question marked as completed"),
	})
}
// ResetQuestionCompleted handles DELETE /v1/daily/questions/{date}/complete/{questionId}.
//
// It validates the session and path parameters, parses the date in the user's
// timezone, and asks the daily question service to reset the assignment's
// completion status. Responds 200 with a SuccessResponse; failures are
// reported via HandleAppError (ErrAssignmentNotFound is passed through as-is).
func (h *DailyQuestionHandler) ResetQuestionCompleted(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "reset_daily_question_completed")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Both path parameters are required.
	rawDate, rawQuestionID := c.Param("date"), c.Param("questionId")
	if rawDate == "" || rawQuestionID == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Interpret the date in the user's timezone.
	day, tzName, parseErr := h.ParseDateInUserTimezone(ctx, userID, rawDate)
	if parseErr != nil {
		// Anything other than a malformed date is reported as a wrapped error.
		if !strings.Contains(parseErr.Error(), "invalid date format") {
			HandleAppError(c, contextutils.WrapError(parseErr, "failed to get user information"))
			return
		}
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	questionID, convErr := strconv.Atoi(rawQuestionID)
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Record request details on the span.
	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", rawDate),
		attribute.Int("question_id", questionID),
		attribute.String("timezone", tzName),
	)

	if svcErr := h.dailyQuestionService.ResetQuestionCompleted(ctx, userID, questionID, day); svcErr != nil {
		h.logger.Error(ctx, "Failed to reset daily question completion", svcErr, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
			"date":        rawDate,
			"timezone":    tzName,
		})
		// A missing assignment is surfaced with its dedicated error.
		if contextutils.IsError(svcErr, contextutils.ErrAssignmentNotFound) {
			HandleAppError(c, contextutils.ErrAssignmentNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(svcErr, "failed to reset question completion"))
		}
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{
		Success: true,
		Message: stringPtr("Question completion reset"),
	})
}
// GetAvailableDates handles GET /v1/daily/dates.
//
// It returns the dates on which the session user has daily assignments,
// formatted as YYYY-MM-DD, excluding calendar days that are still in the
// future in the user's timezone. The timezone falls back to UTC when the user
// cannot be loaded, has no timezone set, or the timezone name is unknown.
//
// Responds 200 with {"dates": [...]}; errors go through HandleAppError.
func (h *DailyQuestionHandler) GetAvailableDates(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_daily_available_dates")
	defer observability.FinishSpan(span, nil)
	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	// Add span attributes for observability
	span.SetAttributes(observability.AttributeUserID(userID))
	// Get available dates with assignments
	dates, err := h.dailyQuestionService.GetAvailableDates(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get available dates", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get available dates"))
		return
	}
	// Exclude future dates based on the user's timezone (clients expect local
	// calendar days only). The lookup error is deliberately ignored: filtering
	// is best-effort and falls back to UTC.
	user, _ := h.userService.GetUserByID(ctx, userID)
	tz := "UTC"
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		tz = user.Timezone.String
	}
	loc, err := time.LoadLocation(tz)
	if err != nil {
		loc = time.UTC
	}
	now := time.Now().In(loc)
	today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, loc)
	// Keep dates that are not after today, comparing calendar days only. The
	// stored year/month/day is re-anchored at midnight in the user's location
	// so the offset between the stored value's zone and the user's zone cannot
	// push "today" into the future. These are the same date components that
	// Format emits below.
	dateStrings := make([]string, 0, len(dates))
	for _, d := range dates {
		year, month, day := d.Date()
		calendarDay := time.Date(year, month, day, 0, 0, 0, 0, loc)
		if calendarDay.After(today) {
			continue
		}
		dateStrings = append(dateStrings, d.Format("2006-01-02"))
	}
	c.JSON(http.StatusOK, gin.H{
		"dates": dateStrings,
	})
}
// Note: Daily question assignment is now handled automatically by the worker
// when sending daily reminder emails. No manual assignment endpoint needed.
// GetDailyProgress handles GET /v1/daily/progress/{date}.
//
// It validates the session and the date parameter, parses the date in the
// user's timezone, loads the user's progress for that date from the daily
// question service, and responds 200 with the converted api.DailyProgress.
// Failures are reported via HandleAppError.
func (h *DailyQuestionHandler) GetDailyProgress(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_daily_progress")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// The date path parameter is required.
	rawDate := c.Param("date")
	if rawDate == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Interpret the date in the user's timezone.
	day, tzName, parseErr := h.ParseDateInUserTimezone(ctx, userID, rawDate)
	if parseErr != nil {
		// Anything other than a malformed date is reported as a wrapped error.
		if !strings.Contains(parseErr.Error(), "invalid date format") {
			HandleAppError(c, contextutils.WrapError(parseErr, "failed to get user information"))
			return
		}
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Record request details on the span.
	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", rawDate),
		attribute.String("timezone", tzName),
	)

	progress, svcErr := h.dailyQuestionService.GetDailyProgress(ctx, userID, day)
	if svcErr != nil {
		h.logger.Error(ctx, "Failed to get daily progress", svcErr, map[string]interface{}{
			"user_id":  userID,
			"date":     rawDate,
			"timezone": tzName,
		})
		HandleAppError(c, contextutils.WrapError(svcErr, "failed to get daily progress"))
		return
	}

	// Respond with the API representation produced by the shared converter.
	c.JSON(http.StatusOK, convertDailyProgressToAPI(progress))
}
// SubmitDailyQuestionAnswer handles POST /v1/daily/questions/{date}/answer/{questionId}.
//
// It validates the session and path parameters, parses the date in the user's
// timezone, binds the JSON body, rejects negative answer indexes, and submits
// the answer through the daily question service. Responds 200 with the
// answer result plus "is_completed": true. ErrQuestionAlreadyAnswered,
// ErrAssignmentNotFound and ErrInvalidAnswerIndex from the service are passed
// to HandleAppError as-is; any other failure is wrapped.
func (h *DailyQuestionHandler) SubmitDailyQuestionAnswer(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "submit_daily_question_answer")
	defer observability.FinishSpan(span, nil)

	h.logger.Info(ctx, "SubmitDailyQuestionAnswer handler called", map[string]interface{}{
		"method": c.Request.Method,
		"path":   c.Request.URL.Path,
		"params": c.Params,
	})

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Both path parameters are required.
	rawDate, rawQuestionID := c.Param("date"), c.Param("questionId")
	if rawDate == "" || rawQuestionID == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Interpret the date in the user's timezone.
	day, tzName, parseErr := h.ParseDateInUserTimezone(ctx, userID, rawDate)
	if parseErr != nil {
		// Anything other than a malformed date is reported as a wrapped error.
		if !strings.Contains(parseErr.Error(), "invalid date format") {
			HandleAppError(c, contextutils.WrapError(parseErr, "failed to get user information"))
			return
		}
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	questionID, convErr := strconv.Atoi(rawQuestionID)
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Bind the JSON request body.
	var body api.PostV1DailyQuestionsDateAnswerQuestionIdJSONBody
	h.logger.Info(ctx, "Parsing request body", map[string]interface{}{
		"user_id":     userID,
		"question_id": questionID,
		"date":        rawDate,
		"timezone":    tzName,
	})
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		h.logger.Error(ctx, "Failed to parse request body", bindErr, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
			"date":        rawDate,
			"timezone":    tzName,
			"error":       bindErr.Error(),
		})
		invalidBody := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalidBody)
		return
	}
	answerIndex := body.UserAnswerIndex
	h.logger.Info(ctx, "Request body parsed successfully",
		map[string]interface{}{
			"user_id":           userID,
			"question_id":       questionID,
			"date":              rawDate,
			"timezone":          tzName,
			"user_answer_index": answerIndex,
		})

	// Negative indexes can never address an option.
	if answerIndex < 0 {
		h.logger.Warn(ctx, "Invalid user answer index in SubmitDailyQuestionAnswer", map[string]interface{}{"user_answer_index": answerIndex})
		HandleAppError(c, contextutils.ErrInvalidAnswerIndex)
		return
	}

	// Record request details on the span.
	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", rawDate),
		attribute.Int("question_id", questionID),
		attribute.String("timezone", tzName),
		attribute.Int("user_answer_index", answerIndex),
	)

	result, svcErr := h.dailyQuestionService.SubmitDailyQuestionAnswer(ctx, userID, questionID, day, answerIndex)
	if svcErr != nil {
		h.logger.Error(ctx, "Failed to submit daily question answer", svcErr, map[string]interface{}{
			"user_id":           userID,
			"question_id":       questionID,
			"date":              rawDate,
			"timezone":          tzName,
			"user_answer_index": answerIndex,
		})
		// Known service errors are surfaced with their dedicated error value.
		switch {
		case contextutils.IsError(svcErr, contextutils.ErrQuestionAlreadyAnswered):
			HandleAppError(c, contextutils.ErrQuestionAlreadyAnswered)
		case contextutils.IsError(svcErr, contextutils.ErrAssignmentNotFound):
			HandleAppError(c, contextutils.ErrAssignmentNotFound)
		case contextutils.IsError(svcErr, contextutils.ErrInvalidAnswerIndex):
			HandleAppError(c, contextutils.ErrInvalidAnswerIndex)
		default:
			HandleAppError(c, contextutils.WrapError(svcErr, "failed to submit answer"))
		}
		return
	}

	// The service result is returned together with a fixed completion flag.
	c.JSON(http.StatusOK, gin.H{
		"user_answer_index":    result.UserAnswerIndex,
		"user_answer":          result.UserAnswer,
		"is_correct":           result.IsCorrect,
		"correct_answer_index": result.CorrectAnswerIndex,
		"explanation":          result.Explanation,
		"is_completed":         true,
	})
}
// GetQuestionHistory handles GET /v1/daily/questions/{questionId}/history.
//
// It returns the session user's history for one question over the last 14
// days (the window passed to the service). Entries whose assignment date is
// still in the future for the user are skipped. Each entry contains
// assignment_date (YYYY-MM-DD), is_completed, is_correct (null when unknown)
// and submitted_at (RFC3339 in the user's timezone, null when not submitted).
//
// Responds 200 with {"history": [...]}; errors go through HandleAppError. A
// failure to format submitted_at aborts the request with an error response.
func (h *DailyQuestionHandler) GetQuestionHistory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_question_history")
	defer observability.FinishSpan(span, nil)
	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	// Parse question ID parameter
	questionIDStr := c.Param("questionId")
	if questionIDStr == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	questionID, err := strconv.Atoi(questionIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	// Add span attributes for observability
	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.Int("question_id", questionID),
	)
	// Get question history for the last 14 days
	history, err := h.dailyQuestionService.GetQuestionHistory(ctx, userID, questionID, 14)
	if err != nil {
		h.logger.Error(ctx, "Failed to get question history", err, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get question history"))
		return
	}
	// Determine the user's timezone once. The lookup error is deliberately
	// ignored: filtering is best-effort and falls back to UTC.
	user, _ := h.userService.GetUserByID(ctx, userID)
	tz := "UTC"
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		tz = user.Timezone.String
	}
	loc, locErr := time.LoadLocation(tz)
	if locErr != nil {
		loc = time.UTC
	}
	now := time.Now().In(loc)
	today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, loc)
	// Format times in user's timezone using helper, skipping future dates
	resp := make([]map[string]interface{}, 0, len(history))
	for _, he := range history {
		// Skip future assignments, comparing calendar days only. The day is
		// taken from the stored UTC value (the same one returned below) and
		// re-anchored at local midnight; shifting the instant into the user's
		// zone first could move it onto the previous calendar day and let a
		// future assignment through.
		storedDay := he.AssignmentDate.UTC()
		adDate := time.Date(storedDay.Year(), storedDay.Month(), storedDay.Day(), 0, 0, 0, 0, loc)
		if adDate.After(today) {
			continue
		}
		// Return assignment_date as date-only string (YYYY-MM-DD) using the stored UTC
		// date to avoid timezone ambiguity for clients.
		assignDateStr := storedDay.Format("2006-01-02")
		span.SetAttributes(attribute.String("assignment_date.formatted_with", "date_only"))
		entry := map[string]interface{}{
			"assignment_date": assignDateStr,
			"is_completed":    he.IsCompleted,
			"is_correct":      nil,
			"submitted_at":    nil,
		}
		if he.IsCorrect != nil {
			entry["is_correct"] = *he.IsCorrect
		}
		if he.SubmittedAt != nil {
			submittedStr, _, submittedErr := contextutils.FormatTimeInUserTimezone(ctx, userID, *he.SubmittedAt, time.RFC3339, h.userService.GetUserByID)
			if submittedErr != nil || submittedStr == "" {
				// An empty result without an error is still a failure. Give it
				// a concrete error so the logging, span and response code
				// below never receive a nil error.
				if submittedErr == nil {
					submittedErr = contextutils.NewAppError(
						contextutils.ErrorCodeInternalError,
						contextutils.SeverityError,
						"failed to format submitted_at",
						"formatter returned an empty string",
					)
				}
				h.logger.Error(ctx, "Failed to format submitted_at in user's timezone", submittedErr, map[string]interface{}{
					"user_id":         userID,
					"question_id":     questionID,
					"submitted_at_db": he.SubmittedAt,
				})
				span.RecordError(submittedErr, trace.WithStackTrace(true))
				span.SetStatus(codes.Error, "failed to format submitted_at")
				HandleAppError(c, contextutils.WrapError(submittedErr, "failed to format submitted_at"))
				return
			}
			span.SetAttributes(attribute.String("submitted_at.formatted_with", "user_timezone"))
			entry["submitted_at"] = submittedStr
		}
		resp = append(resp, entry)
	}
	c.JSON(http.StatusOK, gin.H{"history": resp})
}
package handlers
import (
	"errors"
	"fmt"
	"net/http"

	contextutils "quizapp/internal/utils"

	"github.com/gin-gonic/gin"
)
// StandardizeHTTPError writes a structured JSON error for a raw HTTP status.
//
// The status code is translated into an error code and severity (400, 401,
// 403, 404, 409 and 503 have dedicated mappings; every other status is
// reported as an internal error), an AppError is built from those plus
// message and details, and its JSON form is sent with the original statusCode.
func StandardizeHTTPError(c *gin.Context, statusCode int, message, details string) {
	// Start from the fallback; recognised statuses override it below.
	var (
		code  contextutils.ErrorCode     = contextutils.ErrorCodeInternalError
		level contextutils.SeverityLevel = contextutils.SeverityError
	)
	switch statusCode {
	case http.StatusBadRequest:
		code, level = contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn
	case http.StatusUnauthorized:
		code, level = contextutils.ErrorCodeUnauthorized, contextutils.SeverityWarn
	case http.StatusForbidden:
		code, level = contextutils.ErrorCodeForbidden, contextutils.SeverityWarn
	case http.StatusNotFound:
		code, level = contextutils.ErrorCodeRecordNotFound, contextutils.SeverityInfo
	case http.StatusConflict:
		code, level = contextutils.ErrorCodeRecordExists, contextutils.SeverityInfo
	case http.StatusServiceUnavailable:
		code, level = contextutils.ErrorCodeServiceUnavailable, contextutils.SeverityError
	}

	// The response keeps the caller's status code.
	payload := contextutils.NewAppError(code, level, message, details).ToJSON()
	c.JSON(statusCode, payload)
}
// StandardizeAppError writes err as a structured JSON error response. The
// HTTP status is derived from err.Code via mapErrorCodeToHTTPStatus, and the
// JSON body is err.ToJSON() extended with a "retryable" entry taken from
// contextutils.IsRetryable.
func StandardizeAppError(c *gin.Context, err *contextutils.AppError) {
	status := mapErrorCodeToHTTPStatus(err.Code)

	body := err.ToJSON()
	body["retryable"] = contextutils.IsRetryable(err)

	c.JSON(status, body)
}
// HandleValidationError reports an invalid input value. It builds an
// invalid-input AppError (warn severity) whose message names the field and
// whose details include the rejected value and the reason, then sends it via
// StandardizeAppError.
func HandleValidationError(c *gin.Context, field string, value interface{}, reason string) {
	message := fmt.Sprintf("Invalid %s", field)
	details := fmt.Sprintf("Value '%v' is invalid: %s", value, reason)

	StandardizeAppError(c, contextutils.NewAppError(
		contextutils.ErrorCodeInvalidInput,
		contextutils.SeverityWarn,
		message,
		details,
	))
}
// HandleAppError converts err into an HTTP response.
//
//   - If err is, or wraps, a *contextutils.AppError, that AppError determines
//     the response. The "no questions available" code is special-cased to
//     202 Accepted with a generating payload; every other code goes through
//     StandardizeAppError.
//   - Any other non-nil error becomes a 500 whose details are err.Error().
//   - A nil err (a caller bug) becomes a 500 with empty details instead of
//     panicking.
func HandleAppError(c *gin.Context, err error) {
	if err == nil {
		// Nothing to describe; still answer rather than dereference nil.
		StandardizeHTTPError(c, http.StatusInternalServerError, "Internal server error", "")
		return
	}

	// errors.As also finds an AppError that was wrapped further up the chain,
	// which a plain type assertion would miss.
	var appErr *contextutils.AppError
	if errors.As(err, &appErr) && appErr != nil {
		// Special-case: no questions available should return 202 with GeneratingResponse body
		if appErr.Code == contextutils.ErrorCodeNoQuestionsAvailable {
			// 202 Accepted with generating payload (matches swagger GeneratingResponse)
			c.JSON(http.StatusAccepted, gin.H{
				"status":  "generating",
				"message": "No questions available. Please try again shortly.",
			})
			return
		}
		StandardizeAppError(c, appErr)
		return
	}

	// Fallback for non-AppError types
	StandardizeHTTPError(c, http.StatusInternalServerError, "Internal server error", err.Error())
}
// mapErrorCodeToHTTPStatus returns the HTTP status used for an AppError code.
// Codes that are not listed fall back to 500 Internal Server Error.
func mapErrorCodeToHTTPStatus(code contextutils.ErrorCode) int {
	status := http.StatusInternalServerError

	switch code {
	// 202: questions are still being generated.
	case contextutils.ErrorCodeNoQuestionsAvailable:
		status = http.StatusAccepted

	// 400: malformed or invalid client input.
	case contextutils.ErrorCodeInvalidInput,
		contextutils.ErrorCodeMissingRequired,
		contextutils.ErrorCodeInvalidFormat,
		contextutils.ErrorCodeValidationFailed,
		contextutils.ErrorCodeOAuthStateMismatch:
		status = http.StatusBadRequest

	// 401: missing, expired or wrong credentials.
	case contextutils.ErrorCodeUnauthorized,
		contextutils.ErrorCodeSessionExpired,
		contextutils.ErrorCodeInvalidCredentials:
		status = http.StatusUnauthorized

	// 403
	case contextutils.ErrorCodeForbidden:
		status = http.StatusForbidden

	// 404
	case contextutils.ErrorCodeRecordNotFound,
		contextutils.ErrorCodeQuestionNotFound,
		contextutils.ErrorCodeAssignmentNotFound:
		status = http.StatusNotFound

	// 408
	// NOTE(review): 408 is defined as a client request timeout; confirm
	// whether a server-side timeout should map to 504 instead.
	case contextutils.ErrorCodeTimeout:
		status = http.StatusRequestTimeout

	// 409
	case contextutils.ErrorCodeRecordExists,
		contextutils.ErrorCodeGenerationLimitReached:
		status = http.StatusConflict

	// 429
	case contextutils.ErrorCodeRateLimit:
		status = http.StatusTooManyRequests

	// 503: a dependency is unavailable.
	case contextutils.ErrorCodeServiceUnavailable,
		contextutils.ErrorCodeDatabaseConnection,
		contextutils.ErrorCodeAIProviderUnavailable:
		status = http.StatusServiceUnavailable

	// 500: listed explicitly for documentation; same as the fallback.
	case contextutils.ErrorCodeInternalError,
		contextutils.ErrorCodeDatabaseQuery,
		contextutils.ErrorCodeDatabaseTransaction,
		contextutils.ErrorCodeForeignKeyViolation,
		contextutils.ErrorCodeTimestampMissingTimezone,
		contextutils.ErrorCodeAIRequestFailed,
		contextutils.ErrorCodeAIResponseInvalid,
		contextutils.ErrorCodeAIConfigInvalid,
		contextutils.ErrorCodeOAuthProviderError:
		status = http.StatusInternalServerError
	}

	return status
}
package handlers
import (
"database/sql"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
serviceinterfaces "quizapp/internal/serviceinterfaces"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
)
// FeedbackResponse is the JSON shape returned by the feedback endpoints
// (submit, get, list and update). It is built from a models.FeedbackReport by
// convertFeedbackToResponse.
//
// Pointer fields are left nil (encoded as JSON null) when the corresponding
// nullable column on the report is not valid.
type FeedbackResponse struct {
ID int `json:"id"`
UserID int `json:"user_id"`
FeedbackText string `json:"feedback_text"`
FeedbackType string `json:"feedback_type"`
// ContextData is never nil when produced by convertFeedbackToResponse; a
// missing map is replaced with an empty one.
ContextData map[string]interface{} `json:"context_data"`
ScreenshotData *string `json:"screenshot_data"`
ScreenshotURL *string `json:"screenshot_url"`
Status string `json:"status"`
AdminNotes *string `json:"admin_notes"`
AssignedToUserID *int32 `json:"assigned_to_user_id"`
// ResolvedAt uses the same "2006-01-02T15:04:05Z07:00" layout as CreatedAt
// and UpdatedAt.
ResolvedAt *string `json:"resolved_at"`
ResolvedByUserID *int32 `json:"resolved_by_user_id"`
// CreatedAt and UpdatedAt are formatted with the layout
// "2006-01-02T15:04:05Z07:00".
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// ensureContextDataNotNull guarantees a non-nil map: a non-nil input is
// returned as-is, a nil input is replaced by a fresh empty map.
func ensureContextDataNotNull(data map[string]interface{}) map[string]interface{} {
	if data != nil {
		return data
	}
	return make(map[string]interface{})
}
// convertFeedbackToResponse maps a models.FeedbackReport onto the API-facing
// FeedbackResponse. Timestamps are rendered with the layout
// "2006-01-02T15:04:05Z07:00", a nil context map becomes an empty map, and
// each nullable column is copied only when it is valid (otherwise the
// corresponding pointer stays nil).
func convertFeedbackToResponse(fr models.FeedbackReport) FeedbackResponse {
	const timestampLayout = "2006-01-02T15:04:05Z07:00"

	var response FeedbackResponse
	response.ID = fr.ID
	response.UserID = fr.UserID
	response.FeedbackText = fr.FeedbackText
	response.FeedbackType = fr.FeedbackType
	response.ContextData = ensureContextDataNotNull(fr.ContextData)
	response.Status = fr.Status
	response.CreatedAt = fr.CreatedAt.Format(timestampLayout)
	response.UpdatedAt = fr.UpdatedAt.Format(timestampLayout)

	if fr.ScreenshotData.Valid {
		data := fr.ScreenshotData.String
		response.ScreenshotData = &data
	}
	if fr.ScreenshotURL.Valid {
		url := fr.ScreenshotURL.String
		response.ScreenshotURL = &url
	}
	if fr.AdminNotes.Valid {
		notes := fr.AdminNotes.String
		response.AdminNotes = &notes
	}
	if fr.AssignedToUserID.Valid {
		assignee := fr.AssignedToUserID.Int32
		response.AssignedToUserID = &assignee
	}
	if fr.ResolvedAt.Valid {
		resolvedAt := fr.ResolvedAt.Time.Format(timestampLayout)
		response.ResolvedAt = &resolvedAt
	}
	if fr.ResolvedByUserID.Valid {
		resolver := fr.ResolvedByUserID.Int32
		response.ResolvedByUserID = &resolver
	}

	return response
}
// FeedbackHandler handles feedback report endpoints.
type FeedbackHandler struct {
// feedbackService performs all feedback persistence operations.
feedbackService serviceinterfaces.FeedbackServiceInterface
// linearService creates Linear issues; CreateLinearIssue treats a nil value
// as "integration not available".
linearService *services.LinearService
// userService is optional; it is nil-checked before being used to look up
// usernames and user timezones.
userService services.UserServiceInterface
// config is optional; when set, Server.AppBaseURL is used to turn relative
// page URLs into absolute ones.
config *config.Config
logger *observability.Logger
}
// NewFeedbackHandler builds a FeedbackHandler from its dependencies. The
// arguments are stored as given; no validation is performed here.
func NewFeedbackHandler(fs serviceinterfaces.FeedbackServiceInterface, linearService *services.LinearService, userService services.UserServiceInterface, cfg *config.Config, logger *observability.Logger) *FeedbackHandler {
	handler := &FeedbackHandler{}
	handler.feedbackService = fs
	handler.linearService = linearService
	handler.userService = userService
	handler.config = cfg
	handler.logger = logger
	return handler
}
// FeedbackSubmissionRequest is the JSON body accepted by SubmitFeedback
// (POST /v1/feedback).
type FeedbackSubmissionRequest struct {
// FeedbackText is the only mandatory field (binding:"required").
FeedbackText string `json:"feedback_text" binding:"required"`
// FeedbackType defaults to "general" in SubmitFeedback when empty.
FeedbackType string `json:"feedback_type"`
ContextData map[string]interface{} `json:"context_data"`
// ScreenshotData is stored as a NULL column when empty.
ScreenshotData string `json:"screenshot_data"`
}
// SubmitFeedback handles POST /v1/feedback.
//
// It requires an authenticated session, binds the JSON body into a
// FeedbackSubmissionRequest, stores a new report with status "new" (type
// "general" when none was given) and answers 201 Created with the stored
// report.
func (h *FeedbackHandler) SubmitFeedback(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "submit_feedback")
	defer observability.FinishSpan(span, nil)

	// The auth middleware puts the user ID on the Gin context.
	userID, authenticated := GetUserIDFromSession(c)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	// Propagate the user ID to the service layers through the Go context.
	ctx = contextutils.WithUserID(ctx, userID)

	var req FeedbackSubmissionRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		invalidBody := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalidBody)
		return
	}

	report := &models.FeedbackReport{
		UserID:       userID,
		FeedbackText: req.FeedbackText,
		FeedbackType: req.FeedbackType,
		ContextData:  req.ContextData,
		// An empty screenshot is stored as NULL.
		ScreenshotData: sql.NullString{String: req.ScreenshotData, Valid: req.ScreenshotData != ""},
		Status:         "new",
	}
	if report.FeedbackType == "" {
		report.FeedbackType = "general"
	}

	created, err := h.feedbackService.CreateFeedback(ctx, report)
	if err != nil {
		h.logger.Error(ctx, "create feedback failed", err, nil)
		HandleAppError(c, err)
		return
	}

	c.JSON(http.StatusCreated, convertFeedbackToResponse(*created))
}
// GetFeedback handles GET /v1/admin/backend/feedback/:id.
//
// A non-integer ID is rejected with ErrInvalidFormat, a missing report with
// ErrRecordNotFound; any other service error is logged and passed to
// HandleAppError. On success the report is returned with 200 OK.
func (h *FeedbackHandler) GetFeedback(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_feedback")
	defer observability.FinishSpan(span, nil)

	feedbackID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	report, err := h.feedbackService.GetFeedbackByID(ctx, feedbackID)
	switch {
	case err == nil:
		c.JSON(http.StatusOK, convertFeedbackToResponse(*report))
	case contextutils.IsError(err, contextutils.ErrRecordNotFound):
		// Not-found is an expected outcome, so it is not logged as an error.
		HandleAppError(c, contextutils.ErrRecordNotFound)
	default:
		h.logger.Error(ctx, "get feedback failed", err, nil)
		HandleAppError(c, err)
	}
}
// ListFeedback handles GET /v1/admin/feedback.
//
// Query parameters:
//   - page, page_size: pagination; missing, non-numeric or non-positive
//     values fall back to 1 and 20 respectively.
//   - status, feedback_type: optional filters passed through unchanged.
//   - user_id: optional filter; a non-integer value is rejected with
//     ErrInvalidFormat instead of silently filtering on user 0.
//
// The response is {"items": [...], "total": n, "page": p, "page_size": s}.
func (h *FeedbackHandler) ListFeedback(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "list_feedback")
	defer observability.FinishSpan(span, nil)

	page, err := strconv.Atoi(c.DefaultQuery("page", "1"))
	if err != nil || page < 1 {
		page = 1
	}
	pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if err != nil || pageSize < 1 {
		pageSize = 20
	}

	status := c.Query("status")
	feedbackType := c.Query("feedback_type")

	// userID stays nil when no filter was requested.
	var userID *int
	if raw := c.Query("user_id"); raw != "" {
		id, convErr := strconv.Atoi(raw)
		if convErr != nil {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
		userID = &id
	}

	list, total, err := h.feedbackService.GetFeedbackPaginated(ctx, page, pageSize, status, feedbackType, userID)
	if err != nil {
		h.logger.Error(ctx, "list feedback failed", err, nil)
		HandleAppError(c, err)
		return
	}

	// Convert each feedback item to response format
	items := make([]FeedbackResponse, len(list))
	for i, item := range list {
		items[i] = convertFeedbackToResponse(item)
	}

	c.JSON(http.StatusOK, gin.H{"items": items, "total": total, "page": page, "page_size": pageSize})
}
// UpdateFeedback handles PATCH /v1/admin/feedback/:id.
//
// The body is a JSON object of field updates that is passed to the feedback
// service unchanged. A non-integer ID is rejected with ErrInvalidFormat and an
// unparseable body with an ErrorCodeInvalidInput AppError, matching
// GetFeedback and SubmitFeedback, so that client mistakes are reported through
// the AppError path rather than HandleAppError's generic 500 fallback.
func (h *FeedbackHandler) UpdateFeedback(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "update_feedback")
	defer observability.FinishSpan(span, nil)

	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var updates map[string]interface{}
	if err := c.ShouldBindJSON(&updates); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			err,
		))
		return
	}

	updated, err := h.feedbackService.UpdateFeedback(ctx, id, updates)
	if err != nil {
		h.logger.Error(ctx, "update feedback failed", err, nil)
		HandleAppError(c, err)
		return
	}

	c.JSON(http.StatusOK, convertFeedbackToResponse(*updated))
}
// DeleteFeedback handles DELETE /v1/admin/backend/feedback/:id.
//
// A non-integer ID is rejected with ErrInvalidFormat (as in GetFeedback) so
// that it is reported through the AppError path rather than HandleAppError's
// generic 500 fallback. On success the response is 204 No Content.
func (h *FeedbackHandler) DeleteFeedback(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "delete_feedback")
	defer observability.FinishSpan(span, nil)

	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	if err := h.feedbackService.DeleteFeedback(ctx, id); err != nil {
		h.logger.Error(ctx, "delete feedback failed", err, nil)
		HandleAppError(c, err)
		return
	}

	c.Status(http.StatusNoContent)
}
// DeleteFeedbackByStatus handles DELETE /v1/admin/backend/feedback?status=resolved.
//
// The status query parameter is mandatory; without it the request fails with
// ErrMissingRequired. On success the number of deleted reports is returned as
// {"deleted_count": n}.
func (h *FeedbackHandler) DeleteFeedbackByStatus(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "delete_feedback_by_status")
	defer observability.FinishSpan(span, nil)

	targetStatus := c.Query("status")
	if len(targetStatus) == 0 {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	deleted, err := h.feedbackService.DeleteFeedbackByStatus(ctx, targetStatus)
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"deleted_count": deleted})
		return
	}

	h.logger.Error(ctx, "delete feedback by status failed", err, nil)
	HandleAppError(c, err)
}
// DeleteAllFeedback handles DELETE /v1/admin/backend/feedback?all=true.
//
// It asks the feedback service to delete every report and answers with
// {"deleted_count": n}. The handler itself does not inspect the "all" query
// parameter.
func (h *FeedbackHandler) DeleteAllFeedback(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "delete_all_feedback")
	defer observability.FinishSpan(span, nil)

	deleted, err := h.feedbackService.DeleteAllFeedback(ctx)
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"deleted_count": deleted})
		return
	}

	h.logger.Error(ctx, "delete all feedback failed", err, nil)
	HandleAppError(c, err)
}
// CreateLinearIssueResponse is the JSON body returned by CreateLinearIssue.
// Its fields are copied from the result of LinearService.CreateIssue.
type CreateLinearIssueResponse struct {
IssueID string `json:"issue_id"`
IssueURL string `json:"issue_url"`
Title string `json:"title"`
}
// CreateLinearIssue handles POST /v1/admin/backend/feedback/:id/linear-issue.
//
// It loads the feedback report, renders a markdown description (feedback
// text, metadata, context data and an optional screenshot) and creates a
// Linear issue using the configured default team and project.
//
// Responses:
//   - ErrorCodeServiceUnavailable when no Linear service is configured
//   - ErrInvalidFormat when :id is not an integer
//   - ErrRecordNotFound when the report does not exist
//   - 200 with a CreateLinearIssueResponse on success
func (h *FeedbackHandler) CreateLinearIssue(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "create_linear_issue")
	defer observability.FinishSpan(span, nil)

	if h.linearService == nil {
		HandleAppError(c, contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			"Linear integration is not available",
			"",
		))
		return
	}

	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Get feedback by ID
	feedback, err := h.feedbackService.GetFeedbackByID(ctx, id)
	if err != nil {
		if contextutils.IsError(err, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		h.logger.Error(ctx, "get feedback failed", err, nil)
		HandleAppError(c, err)
		return
	}

	// Title only carries the feedback number and type.
	typeLabel := getTypeLabel(feedback.FeedbackType)
	title := fmt.Sprintf("[Feedback #%d] %s", feedback.ID, typeLabel)

	// Username is best-effort: fall back to "User <id>" when lookup fails.
	username := fmt.Sprintf("User %d", feedback.UserID)
	if h.userService != nil {
		if user, userErr := h.userService.GetUserByID(ctx, feedback.UserID); userErr == nil && user != nil {
			username = user.Username
		}
	}

	const humanTimeLayout = "January 2, 2006 at 3:04 PM"

	// formatTime renders t in the reporting user's timezone when it can be
	// resolved; otherwise it formats t as-is and labels it "UTC".
	formatTime := func(t time.Time) (string, string) {
		if h.userService != nil {
			if formatted, tz, fmtErr := contextutils.FormatTimeInUserTimezone(ctx, feedback.UserID, t, humanTimeLayout, h.userService.GetUserByID); fmtErr == nil {
				return formatted, tz
			}
		}
		return t.Format(humanTimeLayout), "UTC"
	}

	// Build description with feedback details
	var desc strings.Builder
	desc.WriteString(feedback.FeedbackText)
	desc.WriteString("\n\n")
	desc.WriteString("### Metadata\n\n")
	fmt.Fprintf(&desc, "- **Type**: %s\n", typeLabel)
	fmt.Fprintf(&desc, "- **Status**: %s\n", feedback.Status)
	fmt.Fprintf(&desc, "- **User ID**: %d\n", feedback.UserID)
	fmt.Fprintf(&desc, "- **Username**: %s\n", username)
	fmt.Fprintf(&desc, "- **Feedback ID**: %d\n", feedback.ID)

	createdFormatted, createdTZ := formatTime(feedback.CreatedAt)
	fmt.Fprintf(&desc, "- **Created**: %s (%s)\n", createdFormatted, createdTZ)

	if feedback.AdminNotes.Valid && feedback.AdminNotes.String != "" {
		fmt.Fprintf(&desc, "- **Admin Notes**: %s\n", feedback.AdminNotes.String)
	}

	// Add context data if available.
	// NOTE: map iteration order is unspecified, so the order of these bullet
	// points can differ between calls.
	if len(feedback.ContextData) > 0 {
		desc.WriteString("\n### Context Data\n\n")
		for key, value := range feedback.ContextData {
			switch key {
			case "page_url":
				// Relative paths are turned into absolute URLs when a base
				// URL can be determined.
				pageURL := fmt.Sprintf("%v", value)
				if strings.HasPrefix(pageURL, "/") {
					pageURL = h.linearIssueBaseURL(c) + pageURL
				}
				fmt.Fprintf(&desc, "- **%s**: %s\n", key, pageURL)
			case "timestamp":
				// RFC 3339 strings are shown in a human-readable form;
				// anything else is printed verbatim below.
				if tsStr, isString := value.(string); isString {
					if ts, parseErr := time.Parse(time.RFC3339, tsStr); parseErr == nil {
						formatted, tz := formatTime(ts)
						fmt.Fprintf(&desc, "- **%s**: %s (%s)\n", key, formatted, tz)
						continue
					}
				}
				fmt.Fprintf(&desc, "- **%s**: %v\n", key, value)
			default:
				fmt.Fprintf(&desc, "- **%s**: %v\n", key, value)
			}
		}
	}

	// Add screenshot: prefer the hosted URL, otherwise embed the stored data
	// as a base64 data URI. Both are emitted as a markdown image.
	if feedback.ScreenshotURL.Valid && feedback.ScreenshotURL.String != "" {
		desc.WriteString("\n### Screenshot\n\n")
		fmt.Fprintf(&desc, "![Screenshot](%s)\n", feedback.ScreenshotURL.String)
	} else if feedback.ScreenshotData.Valid && feedback.ScreenshotData.String != "" {
		desc.WriteString("\n### Screenshot\n\n")
		screenshotData := feedback.ScreenshotData.String
		// Ensure it has the data URI prefix
		if !strings.HasPrefix(screenshotData, "data:") {
			screenshotData = "data:image/png;base64," + screenshotData
		}
		fmt.Fprintf(&desc, "![Screenshot](%s)\n", screenshotData)
	}

	desc.WriteString("\n---\n*Created from Quiz Admin Feedback Reports*")
	description := desc.String()

	// Determine labels based on feedback type; other types get no label.
	var labels []string
	switch feedback.FeedbackType {
	case "bug":
		labels = []string{"Bug"}
	case "feature_request":
		labels = []string{"Feature"}
	case "improvement":
		labels = []string{"Improvement"}
	}

	// Create Linear issue (use config defaults for team and project)
	result, err := h.linearService.CreateIssue(ctx, title, description, "", "", labels, "")
	if err != nil {
		h.logger.Error(ctx, "create linear issue failed", err, nil)
		HandleAppError(c, err)
		return
	}

	c.JSON(http.StatusOK, CreateLinearIssueResponse{
		IssueID:  result.IssueID,
		IssueURL: result.IssueURL,
		Title:    result.Title,
	})
}

// linearIssueBaseURL returns the base URL (without trailing slash) used to
// make relative page URLs absolute. It tries, in order, the configured
// Server.AppBaseURL, the request's Origin header, and the scheme+host part of
// the Referer header. It returns "" when none of them is available.
func (h *FeedbackHandler) linearIssueBaseURL(c *gin.Context) string {
	if h.config != nil && h.config.Server.AppBaseURL != "" {
		return strings.TrimSuffix(h.config.Server.AppBaseURL, "/")
	}
	if origin := c.Request.Header.Get("Origin"); origin != "" {
		return strings.TrimSuffix(origin, "/")
	}

	referer := c.Request.Header.Get("Referer")
	// Keep only "<scheme>://<host>": cut at the first "/" after the scheme.
	if schemeIdx := strings.Index(referer, "://"); schemeIdx > 0 {
		if pathIdx := strings.Index(referer[schemeIdx+3:], "/"); pathIdx > 0 {
			referer = referer[:schemeIdx+3+pathIdx]
		}
	}
	return strings.TrimSuffix(referer, "/")
}
// getTypeLabel returns the human-readable label for a feedback type.
// Unrecognized types are returned unchanged.
func getTypeLabel(feedbackType string) string {
	label := feedbackType
	switch feedbackType {
	case "bug":
		label = "Bug Report"
	case "feature_request":
		label = "Feature Request"
	case "general":
		label = "General Feedback"
	case "improvement":
		label = "Improvement"
	}
	return label
}
package handlers
import (
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// ParsePagination reads the "page" and "page_size" query parameters.
//
// A value that is missing, not an integer, or smaller than 1 is replaced by
// the corresponding default. The resulting size is then capped at maxSize.
// It returns (page, size).
func ParsePagination(c *gin.Context, defaultPage, defaultSize, maxSize int) (int, int) {
	page, size := defaultPage, defaultSize

	if parsed, err := strconv.Atoi(c.DefaultQuery("page", strconv.Itoa(defaultPage))); err == nil && parsed >= 1 {
		page = parsed
	}
	if parsed, err := strconv.Atoi(c.DefaultQuery("page_size", strconv.Itoa(defaultSize))); err == nil && parsed >= 1 {
		size = parsed
	}

	// The cap applies to the default as well as to client-supplied values.
	if size > maxSize {
		size = maxSize
	}
	return page, size
}
// ParseFilters collects the query parameters named by keys. Values are
// whitespace-trimmed, and keys whose trimmed value is empty are omitted from
// the returned map.
func ParseFilters(c *gin.Context, keys ...string) map[string]string {
	result := make(map[string]string, len(keys))
	for i := range keys {
		name := keys[i]
		trimmed := strings.TrimSpace(c.Query(name))
		if trimmed == "" {
			continue
		}
		result[name] = trimmed
	}
	return result
}
// WritePaginated sends a 200 OK JSON response of the form
// {<itemsKey>: items, "pagination": pagination, ...extra}.
//
// The caller chooses itemsKey so existing API response shapes are preserved.
// Entries in extra are applied last and therefore overwrite an entry with the
// same key.
func WritePaginated(c *gin.Context, itemsKey string, items, pagination any, extra gin.H) {
	payload := make(gin.H, len(extra)+2)
	payload[itemsKey] = items
	payload["pagination"] = pagination
	for key, value := range extra {
		payload[key] = value
	}
	c.JSON(http.StatusOK, payload)
}
package handlers
import (
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"quizapp/internal/config"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// QuizHandler handles quiz-related HTTP requests including questions and answers
type QuizHandler struct {
// userService loads the requesting user (preferred language, current level).
userService services.UserServiceInterface
// questionService fetches questions, either by ID or the next one for a user.
questionService services.QuestionServiceInterface
aiService services.AIServiceInterface
// learningService records answers and provides progress, preferences and
// per-question confidence levels.
learningService services.LearningServiceInterface
workerService services.WorkerServiceInterface
// hintService is optional; getNextQuestion nil-checks it before recording a
// generation hint.
hintService services.GenerationHintServiceInterface
usageStatsSvc services.UsageStatsServiceInterface
cfg *config.Config
logger *observability.Logger
}
// NewQuizHandler builds a QuizHandler from its dependencies. The arguments
// are stored as given; no validation is performed here.
func NewQuizHandler(
	userService services.UserServiceInterface,
	questionService services.QuestionServiceInterface,
	aiService services.AIServiceInterface,
	learningService services.LearningServiceInterface,
	workerService services.WorkerServiceInterface,
	hintService services.GenerationHintServiceInterface,
	usageStatsSvc services.UsageStatsServiceInterface,
	config *config.Config,
	logger *observability.Logger,
) *QuizHandler {
	handler := &QuizHandler{}
	handler.userService = userService
	handler.questionService = questionService
	handler.aiService = aiService
	handler.learningService = learningService
	handler.workerService = workerService
	handler.hintService = hintService
	handler.usageStatsSvc = usageStatsSvc
	handler.cfg = config
	handler.logger = logger
	return handler
}
// getUserIDFromSession returns the user ID stored in the session for c by
// delegating to the package-level GetUserIDFromSession.
//
// Deprecated: use GetUserIDFromSession in session.go
func (h *QuizHandler) getUserIDFromSession(c *gin.Context) (int, bool) {
return GetUserIDFromSession(c)
}
// GetQuestion handles requests for quiz questions.
//
// It requires an authenticated session. When the route carries an "id"
// parameter that specific question is served; otherwise the next question for
// the user is selected.
func (h *QuizHandler) GetQuestion(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_question")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := GetUserIDFromSession(c)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Add span attributes for observability
	span.SetAttributes(observability.AttributeUserID(userID))

	questionIDStr := c.Param("id")
	if questionIDStr == "" {
		// No explicit ID: pick the next question for this user.
		h.getNextQuestion(c, userID)
		return
	}

	span.SetAttributes(attribute.String("question.id", questionIDStr))
	h.getSpecificQuestion(c, userID, questionIDStr)
}
// getSpecificQuestion serves the question identified by questionIDStr.
//
// The ID must parse as an integer, otherwise an ErrorCodeInvalidInput error is
// returned. The question is returned with its explanation removed and with
// response statistics attached. The user's confidence level is added when it
// can be loaded; a failure to load it is only logged.
func (h *QuizHandler) getSpecificQuestion(c *gin.Context, userID int, questionIDStr string) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_specific_question",
		observability.AttributeUserID(userID),
		attribute.String("question.id_str", questionIDStr),
	)
	defer observability.FinishSpan(span, nil)

	questionID, parseErr := strconv.Atoi(questionIDStr)
	if parseErr != nil {
		invalidID := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid question ID format",
			"Question ID must be a valid integer",
			parseErr,
		)
		HandleAppError(c, invalidID)
		return
	}

	withStats, err := h.questionService.GetQuestionWithStats(ctx, questionID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get question with stats", err, map[string]interface{}{
			"question_id": questionID,
			"user_id":     userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get question with stats"))
		return
	}

	// Convert and hide sensitive information: the explanation must not be
	// sent before the question has been answered.
	apiQuestion := convertQuestionToAPI(withStats.Question)
	apiQuestion.Explanation = nil

	// Add response statistics to the API question
	apiQuestion.CorrectCount = &withStats.CorrectCount
	apiQuestion.IncorrectCount = &withStats.IncorrectCount
	apiQuestion.TotalResponses = &withStats.TotalResponses

	// The confidence level is optional: log and continue when it cannot be
	// loaded rather than failing the request.
	switch confidence, confErr := h.learningService.GetUserQuestionConfidenceLevel(ctx, userID, questionID); {
	case confErr != nil:
		h.logger.Warn(ctx, "Failed to get user confidence level", map[string]interface{}{
			"error":       confErr.Error(),
			"question_id": questionID,
			"user_id":     userID,
		})
	case confidence != nil:
		apiQuestion.ConfidenceLevel = confidence
	}

	c.JSON(http.StatusOK, apiQuestion)
}
// getNextQuestion selects and returns the next question for the user.
//
// Flow:
//  1. Load the user and require a preferred language and current level.
//  2. Resolve language and level (query parameters override the user's
//     stored values) and the question type: the first non-empty entry of
//     ?type= (a "strict" request), otherwise a random type, optionally
//     excluding the types listed in ?exclude_type=.
//  3. Ask the question service for the next question. If that fails and a
//     strict type was requested, retry once without a type.
//  4. When no question is available, respond 202 Accepted with a "generating"
//     payload; otherwise respond 200 with the question, its explanation
//     removed and response statistics attached.
func (h *QuizHandler) getNextQuestion(c *gin.Context, userID int) {
ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_next_question",
observability.AttributeUserID(userID),
)
defer observability.FinishSpan(span, nil)
user, err := h.userService.GetUserByID(ctx, userID)
if err != nil {
h.logger.Error(ctx, "Failed to get user by ID", err, map[string]interface{}{
"user_id": userID,
})
HandleAppError(c, contextutils.WrapError(err, "failed to get user by ID"))
return
}
if user == nil {
span.SetAttributes(attribute.String("error.type", "user_nil"))
HandleAppError(c, contextutils.ErrRecordNotFound)
return
}
// Check if user has required preferences set
if !user.PreferredLanguage.Valid || user.PreferredLanguage.String == "" {
span.SetAttributes(attribute.String("error.type", "missing_language_preference"))
HandleAppError(c, contextutils.NewAppErrorWithCause(
contextutils.ErrorCodeMissingRequired,
contextutils.SeverityWarn,
"Language preference not set",
"Please set your preferred language in settings",
nil,
))
return
}
if !user.CurrentLevel.Valid || user.CurrentLevel.String == "" {
span.SetAttributes(attribute.String("error.type", "missing_level_preference"))
HandleAppError(c, contextutils.NewAppErrorWithCause(
contextutils.ErrorCodeMissingRequired,
contextutils.SeverityWarn,
"Level preference not set",
"Please set your current level in settings",
nil,
))
return
}
// Query parameters take precedence over the user's stored preferences.
language := c.DefaultQuery("language", user.PreferredLanguage.String)
level := c.DefaultQuery("level", user.CurrentLevel.String)
// Handle question type selection based on query parameters
var qType models.QuestionType
requestedTypes := c.Query("type")
strictTypeRequested := false
if requestedTypes != "" {
strictTypeRequested = true
types := strings.Split(requestedTypes, ",")
// Use the first non-empty type from the list; the value is not validated
// here, and qType stays empty if every entry is blank.
for _, t := range types {
if t = strings.TrimSpace(t); t != "" {
qType = models.QuestionType(t)
break
}
}
} else {
// Check if we need to exclude certain types (comma-separated list)
excludeTypes := c.Query("exclude_type")
if excludeTypes != "" {
excludeList := strings.Split(excludeTypes, ",")
var excludeSet []models.QuestionType
for _, t := range excludeList {
if t = strings.TrimSpace(t); t != "" {
excludeSet = append(excludeSet, models.QuestionType(t))
}
}
qType = h.selectRandomQuestionTypeExcluding(excludeSet...)
} else {
// Default random selection
qType = h.selectRandomQuestionType()
}
}
// Add span attributes for observability
span.SetAttributes(
attribute.String("language", language),
attribute.String("level", level),
attribute.String("question.type", string(qType)),
attribute.Bool("strict.type.requested", strictTypeRequested),
)
// Get next question with fallback logic
questionWithStats, err := h.questionService.GetNextQuestion(ctx, userID, language, level, qType)
if err != nil {
h.logger.Error(ctx, "Failed to get next question", err, map[string]interface{}{
"user_id": userID,
"language": language,
"level": level,
"question_type": string(qType),
})
// Fallback: try without question type if strict type was requested
if strictTypeRequested {
h.logger.Info(ctx, "Attempting fallback without question type", map[string]interface{}{
"user_id": userID,
"language": language,
"level": level,
})
questionWithStats, err = h.questionService.GetNextQuestion(ctx, userID, language, level, "")
if err != nil {
h.logger.Error(ctx, "Fallback also failed", err, map[string]interface{}{
"user_id": userID,
"language": language,
"level": level,
})
// The underlying error is only logged; the client always receives
// ErrNoQuestionsAvailable here.
HandleAppError(c, contextutils.ErrNoQuestionsAvailable)
return
}
} else {
HandleAppError(c, contextutils.ErrNoQuestionsAvailable)
return
}
}
// Check if we got a valid question
if questionWithStats == nil || questionWithStats.Question == nil {
h.logger.Error(ctx, "GetNextQuestion returned nil question", nil, map[string]interface{}{
"user_id": userID,
"language": language,
"level": level,
"question_type": string(qType),
})
// If the user strictly requested a type, record a generation hint with short TTL
if strictTypeRequested && h.hintService != nil && qType != "" {
// Best-effort; do not fail the request if hint upsert fails
_ = h.hintService.UpsertHint(ctx, userID, language, level, qType, 10*time.Minute)
}
c.JSON(http.StatusAccepted, api.GeneratingResponse{
Status: stringPtr("generating"),
Message: stringPtr("No questions available. Prioritizing your requested question type. Please try again shortly."),
})
return
}
// Convert to API format and hide sensitive information
apiQuestion := convertQuestionToAPI(questionWithStats.Question)
apiQuestion.Explanation = nil // Hide explanation
// Add response statistics to the API question
apiQuestion.CorrectCount = &questionWithStats.CorrectCount
apiQuestion.IncorrectCount = &questionWithStats.IncorrectCount
apiQuestion.TotalResponses = &questionWithStats.TotalResponses
// Add confidence level if available
if questionWithStats.ConfidenceLevel != nil {
apiQuestion.ConfidenceLevel = questionWithStats.ConfidenceLevel
}
c.JSON(http.StatusOK, apiQuestion)
}
// SubmitAnswer records the user's answer to a question and reports whether it
// was correct.
//
// It requires an authenticated session and a JSON body matching
// api.AnswerRequest. The answer is recorded through the learning service
// before the response is written; the response contains the correctness flag,
// the selected option's text (empty when it cannot be resolved), the
// explanation and the correct answer index.
func (h *QuizHandler) SubmitAnswer(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "submit_answer")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := GetUserIDFromSession(c)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var req api.AnswerRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		h.logger.Error(ctx, "Invalid answer request format", bindErr, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request format",
			"",
			bindErr,
		))
		return
	}

	questionID := int(req.QuestionId)
	answerIdx := int(req.UserAnswerIndex)

	// Get the question
	question, err := h.questionService.GetQuestionByID(ctx, questionID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get question by ID", err, map[string]interface{}{
			"question_id": req.QuestionId,
			"user_id":     userID,
		})
		HandleAppError(c, contextutils.ErrQuestionNotFound)
		return
	}

	isCorrect := answerIdx == question.CorrectAnswer

	// The response time is optional in the request; default to 0.
	responseTimeMs := 0
	if req.ResponseTimeMs != nil {
		responseTimeMs = int(*req.ResponseTimeMs)
	}

	// Use priority-aware recording to ensure priority scores are updated.
	// The user's answer index is stored for future reference.
	err = h.learningService.RecordAnswerWithPriority(ctx, userID, questionID, answerIdx, isCorrect, responseTimeMs)
	if err != nil {
		h.logger.Error(ctx, "Failed to record user response", err, map[string]interface{}{
			"user_id":     userID,
			"question_id": req.QuestionId,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to record response"))
		return
	}

	// Resolve the text of the chosen option; it stays empty when the options
	// are missing, have an unexpected shape, or the index is out of range.
	userAnswerText := ""
	if options, isList := question.Content["options"].([]interface{}); isList && answerIdx >= 0 && answerIdx < len(options) {
		if text, isString := options[answerIdx].(string); isString {
			userAnswerText = text
		}
	}

	c.JSON(http.StatusOK, &api.AnswerResponse{
		IsCorrect:          &isCorrect,
		UserAnswer:         &userAnswerText,
		UserAnswerIndex:    &req.UserAnswerIndex,
		Explanation:        &question.Explanation,
		CorrectAnswerIndex: &question.CorrectAnswer,
	})

	// Add span attributes for observability
	span.SetAttributes(
		attribute.Int("user.id", userID),
		attribute.Int("question.id", questionID),
		attribute.Bool("answer.is_correct", isCorrect),
		attribute.Int("response.time_ms", responseTimeMs),
	)
}
// GetProgress returns the authenticated user's learning progress.
//
// Only the core progress lookup is mandatory: if it fails the request fails.
// Every additional section (worker status, learning preferences, priority
// insights, generation focus, high priority topics, gap analysis, priority
// distribution) is best-effort: a failure is logged as a warning and the
// section is left out of the response.
func (h *QuizHandler) GetProgress(c *gin.Context) {
ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_progress")
defer observability.FinishSpan(span, nil)
userID, exists := GetUserIDFromSession(c)
if !exists {
HandleAppError(c, contextutils.ErrUnauthorized)
return
}
span.SetAttributes(observability.AttributeUserID(userID))
// Mandatory part: without the base progress there is nothing to return.
progress, err := h.learningService.GetUserProgress(ctx, userID)
if err != nil {
h.logger.Error(ctx, "Failed to get user progress", err, map[string]interface{}{
"user_id": userID,
})
HandleAppError(c, contextutils.WrapError(err, "failed to get progress"))
return
}
// Get worker status information
workerStatus, err := h.getWorkerStatusForUser(ctx, userID)
if err != nil {
h.logger.Warn(ctx, "Failed to get worker status for user", map[string]interface{}{
"user_id": userID,
"error": err.Error(),
})
// Don't fail the entire request, just log the warning
}
// Get learning preferences
learningPrefs, err := h.learningService.GetUserLearningPreferences(ctx, userID)
if err != nil {
h.logger.Warn(ctx, "Failed to get learning preferences for user", map[string]interface{}{
"user_id": userID,
"error": err.Error(),
})
// Don't fail the entire request, just log the warning
}
// Get priority insights
priorityInsights, err := h.getPriorityInsightsForUser(ctx, userID)
if err != nil {
h.logger.Warn(ctx, "Failed to get priority insights for user", map[string]interface{}{
"user_id": userID,
"error": err.Error(),
})
// Don't fail the entire request, just log the warning
}
// Get generation focus information
generationFocus, err := h.getGenerationFocusForUser(ctx, userID)
if err != nil {
h.logger.Warn(ctx, "Failed to get generation focus for user", map[string]interface{}{
"user_id": userID,
"error": err.Error(),
})
// Don't fail the entire request, just log the warning
}
// Get high priority topics
highPriorityTopics, err := h.getHighPriorityTopicsForUser(ctx, userID)
if err != nil {
h.logger.Warn(ctx, "Failed to get high priority topics for user", map[string]interface{}{
"user_id": userID,
"error": err.Error(),
})
// Don't fail the entire request, just log the warning
}
// Get gap analysis
gapAnalysis, err := h.getGapAnalysisForUser(ctx, userID)
if err != nil {
h.logger.Warn(ctx, "Failed to get gap analysis for user", map[string]interface{}{
"user_id": userID,
"error": err.Error(),
})
// Don't fail the entire request, just log the warning
}
// Get priority distribution
priorityDistribution, err := h.getPriorityDistributionForUser(ctx, userID)
if err != nil {
h.logger.Warn(ctx, "Failed to get priority distribution for user", map[string]interface{}{
"user_id": userID,
"error": err.Error(),
})
// Don't fail the entire request, just log the warning
}
// Convert models.UserProgress to api.UserProgress
apiProgress := convertUserProgressToAPI(ctx, progress, userID, h.userService.GetUserByID)
// Attach each optional section only when its lookup produced a value.
if workerStatus != nil {
apiProgress.WorkerStatus = workerStatus
}
if learningPrefs != nil {
apiProgress.LearningPreferences = convertLearningPreferencesToAPI(learningPrefs)
}
if priorityInsights != nil {
apiProgress.PriorityInsights = priorityInsights
}
if generationFocus != nil {
apiProgress.GenerationFocus = generationFocus
}
if highPriorityTopics != nil {
apiProgress.HighPriorityTopics = &highPriorityTopics
}
if gapAnalysis != nil {
apiProgress.GapAnalysis = &gapAnalysis
}
if priorityDistribution != nil {
apiProgress.PriorityDistribution = &priorityDistribution
}
c.JSON(http.StatusOK, apiProgress)
}
// GetAITokenUsage responds with the authenticated user's AI token usage
// statistics for the date range given by the required startDate and endDate
// query parameters, both formatted as YYYY-MM-DD.
//
// Each rejection records a short reason in the span's "error" attribute before
// the error response is written.
func (h *QuizHandler) GetAITokenUsage(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_ai_token_usage")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := GetUserIDFromSession(c)
	if !authenticated {
		span.SetAttributes(attribute.String("error", "no_user_session"))
		HandleAppError(c, contextutils.WrapError(contextutils.ErrUnauthorized, "user not authenticated"))
		return
	}
	span.SetAttributes(observability.AttributeUserID(userID))

	// Presence is checked for startDate first, then endDate.
	rawStart, rawEnd := c.Query("startDate"), c.Query("endDate")
	switch {
	case rawStart == "":
		span.SetAttributes(attribute.String("error", "missing_start_date"))
		HandleAppError(c, contextutils.WrapError(contextutils.ErrInvalidInput, "startDate parameter is required"))
		return
	case rawEnd == "":
		span.SetAttributes(attribute.String("error", "missing_end_date"))
		HandleAppError(c, contextutils.WrapError(contextutils.ErrInvalidInput, "endDate parameter is required"))
		return
	}

	const dayLayout = "2006-01-02"

	rangeStart, parseErr := time.Parse(dayLayout, rawStart)
	if parseErr != nil {
		span.SetAttributes(attribute.String("error", "invalid_start_date"))
		HandleAppError(c, contextutils.WrapErrorf(contextutils.ErrInvalidInput, "invalid startDate format: %v", parseErr))
		return
	}

	rangeEnd, parseErr := time.Parse(dayLayout, rawEnd)
	if parseErr != nil {
		span.SetAttributes(attribute.String("error", "invalid_end_date"))
		HandleAppError(c, contextutils.WrapErrorf(contextutils.ErrInvalidInput, "invalid endDate format: %v", parseErr))
		return
	}

	usage, statsErr := h.usageStatsSvc.GetUserAITokenUsageStats(ctx, userID, rangeStart, rangeEnd)
	if statsErr != nil {
		h.logger.Error(ctx, "Failed to get user AI token usage stats", statsErr, map[string]any{
			"user_id":    userID,
			"start_date": rawStart,
			"end_date":   rawEnd,
		})
		HandleAppError(c, contextutils.WrapError(statsErr, "failed to get AI token usage stats"))
		return
	}

	c.JSON(http.StatusOK, usage)
}
// GetAITokenUsageDaily responds with the authenticated user's AI token usage
// aggregated per day for the date range given by the required startDate and
// endDate query parameters, both formatted as YYYY-MM-DD.
//
// Each rejection records a short reason in the span's "error" attribute before
// the error response is written.
func (h *QuizHandler) GetAITokenUsageDaily(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_ai_token_usage_daily")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := GetUserIDFromSession(c)
	if !authenticated {
		span.SetAttributes(attribute.String("error", "no_user_session"))
		HandleAppError(c, contextutils.WrapError(contextutils.ErrUnauthorized, "user not authenticated"))
		return
	}
	span.SetAttributes(observability.AttributeUserID(userID))

	// Presence is checked for startDate first, then endDate.
	rawStart, rawEnd := c.Query("startDate"), c.Query("endDate")
	switch {
	case rawStart == "":
		span.SetAttributes(attribute.String("error", "missing_start_date"))
		HandleAppError(c, contextutils.WrapError(contextutils.ErrInvalidInput, "startDate parameter is required"))
		return
	case rawEnd == "":
		span.SetAttributes(attribute.String("error", "missing_end_date"))
		HandleAppError(c, contextutils.WrapError(contextutils.ErrInvalidInput, "endDate parameter is required"))
		return
	}

	const dayLayout = "2006-01-02"

	rangeStart, parseErr := time.Parse(dayLayout, rawStart)
	if parseErr != nil {
		span.SetAttributes(attribute.String("error", "invalid_start_date"))
		HandleAppError(c, contextutils.WrapErrorf(contextutils.ErrInvalidInput, "invalid startDate format: %v", parseErr))
		return
	}

	rangeEnd, parseErr := time.Parse(dayLayout, rawEnd)
	if parseErr != nil {
		span.SetAttributes(attribute.String("error", "invalid_end_date"))
		HandleAppError(c, contextutils.WrapErrorf(contextutils.ErrInvalidInput, "invalid endDate format: %v", parseErr))
		return
	}

	daily, statsErr := h.usageStatsSvc.GetUserAITokenUsageStatsByDay(ctx, userID, rangeStart, rangeEnd)
	if statsErr != nil {
		h.logger.Error(ctx, "Failed to get user AI token usage stats by day", statsErr, map[string]any{
			"user_id":    userID,
			"start_date": rawStart,
			"end_date":   rawEnd,
		})
		HandleAppError(c, contextutils.WrapError(statsErr, "failed to get daily AI token usage stats"))
		return
	}

	c.JSON(http.StatusOK, daily)
}
// GetAITokenUsageHourly responds with the authenticated user's AI token usage
// aggregated per hour for the single day given by the required date query
// parameter, formatted as YYYY-MM-DD.
//
// Each rejection records a short reason in the span's "error" attribute before
// the error response is written.
func (h *QuizHandler) GetAITokenUsageHourly(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_ai_token_usage_hourly")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := GetUserIDFromSession(c)
	if !authenticated {
		span.SetAttributes(attribute.String("error", "no_user_session"))
		HandleAppError(c, contextutils.WrapError(contextutils.ErrUnauthorized, "user not authenticated"))
		return
	}
	span.SetAttributes(observability.AttributeUserID(userID))

	rawDate := c.Query("date")
	if rawDate == "" {
		span.SetAttributes(attribute.String("error", "missing_date"))
		HandleAppError(c, contextutils.WrapError(contextutils.ErrInvalidInput, "date parameter is required"))
		return
	}

	day, parseErr := time.Parse("2006-01-02", rawDate)
	if parseErr != nil {
		span.SetAttributes(attribute.String("error", "invalid_date"))
		HandleAppError(c, contextutils.WrapErrorf(contextutils.ErrInvalidInput, "invalid date format: %v", parseErr))
		return
	}

	hourly, statsErr := h.usageStatsSvc.GetUserAITokenUsageStatsByHour(ctx, userID, day)
	if statsErr != nil {
		h.logger.Error(ctx, "Failed to get user AI token usage stats by hour", statsErr, map[string]any{
			"user_id": userID,
			"date":    rawDate,
		})
		HandleAppError(c, contextutils.WrapError(statsErr, "failed to get hourly AI token usage stats"))
		return
	}

	c.JSON(http.StatusOK, hourly)
}
// ReportQuestion records a report against the question named by the "id" path
// parameter on behalf of the authenticated user.
//
// The JSON body is optional: when it is absent or cannot be bound, the report
// is filed with an empty reason. A record-not-found error from the question
// service is answered with ErrQuestionNotFound; any other failure is wrapped
// and passed to HandleAppError.
func (h *QuizHandler) ReportQuestion(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "report_question")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := GetUserIDFromSession(c)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	rawID := c.Param("id")
	questionID, convErr := strconv.Atoi(rawID)
	if convErr != nil {
		HandleValidationError(c, "question_id", rawID, "must be a valid integer")
		return
	}

	var body struct {
		ReportReason *string `json:"report_reason"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		// The body is optional, so a binding failure just discards the reason.
		body.ReportReason = nil
	}

	var reason string
	if body.ReportReason != nil {
		reason = *body.ReportReason
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)

	if reportErr := h.questionService.ReportQuestion(ctx, questionID, userID, reason); reportErr != nil {
		h.logger.Error(ctx, "Failed to report question", reportErr, map[string]any{
			"question_id": questionID,
			"user_id":     userID,
		})
		if contextutils.IsError(reportErr, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(reportErr, "failed to report question"))
		}
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{Success: true, Message: stringPtr("Question reported successfully")})
}
// MarkQuestionAsKnown marks the question named by the "id" path parameter as
// known for the authenticated user.
//
// The JSON body is optional and may carry a confidence_level; when the body is
// absent or cannot be bound, a nil confidence level is passed to the learning
// service. An ErrQuestionNotFound from the service is returned as such; any
// other failure is wrapped and passed to HandleAppError.
func (h *QuizHandler) MarkQuestionAsKnown(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "mark_question_as_known")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := GetUserIDFromSession(c)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	rawID := c.Param("id")
	questionID, convErr := strconv.Atoi(rawID)
	if convErr != nil {
		HandleValidationError(c, "question_id", rawID, "must be a valid integer")
		return
	}

	var body struct {
		ConfidenceLevel *int `json:"confidence_level"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		// The body is optional, so a binding failure just discards the level.
		body.ConfidenceLevel = nil
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)

	if markErr := h.learningService.MarkQuestionAsKnown(ctx, userID, questionID, body.ConfidenceLevel); markErr != nil {
		h.logger.Error(ctx, "Failed to mark question as known for user", markErr, map[string]any{
			"question_id": questionID,
			"user_id":     userID,
		})
		if contextutils.IsError(markErr, contextutils.ErrQuestionNotFound) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(markErr, "failed to mark question as known"))
		}
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{Success: true, Message: stringPtr("Question marked as known successfully")})
}
// ChatStream streams an AI-generated chat reply about a quiz question to the
// authenticated user as Server-Sent Events.
//
// The JSON body is bound to api.QuizChatRequest. The question's language, level
// and type are required; a request missing any of them is rejected as invalid
// input instead of being dereferenced. The user's configured AI provider and
// model, plus the saved API key for that provider (if any), are passed to the
// AI service.
//
// Streaming runs under a context that is cancelled after
// config.QuizStreamTimeout or when the request context ends. Each chunk is
// JSON-encoded and written as a "data:" event; a chunk prefixed with "ERROR: "
// is forwarded as an "error" event and ends the stream.
func (h *QuizHandler) ChatStream(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "chat_stream")
	defer observability.FinishSpan(span, nil)

	userID, exists := h.getUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var req api.QuizChatRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request format",
			"",
			err,
		))
		return
	}

	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil || user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("ai.provider", user.AIProvider.String),
		attribute.String("ai.model", user.AIModel.String),
	)

	// Language, Level and Type are pointers and are dereferenced below; reject
	// the request up front rather than panicking on a missing field.
	if req.Question.Language == nil || req.Question.Level == nil || req.Question.Type == nil {
		HandleAppError(c, contextutils.WrapError(contextutils.ErrInvalidInput, "question language, level and type are required"))
		return
	}

	// Prepare the request for the AI service.
	aiReq := &models.AIChatRequest{
		Language:     string(*req.Question.Language),
		Level:        string(*req.Question.Level),
		QuestionType: models.QuestionType(*req.Question.Type),
		UserMessage:  req.UserMessage,
	}
	if req.Question.Content != nil {
		aiReq.Question = req.Question.Content.Question
		aiReq.Options = req.Question.Content.Options
		if req.Question.Content.Passage != nil {
			aiReq.Passage = *req.Question.Content.Passage
		}
		// For vocabulary questions, the sentence field is used as the passage.
		if req.Question.Content.Sentence != nil && *req.Question.Type == "vocabulary" {
			aiReq.Passage = *req.Question.Content.Sentence
		}
	}
	if req.AnswerContext != nil {
		if req.AnswerContext.UserAnswer != nil {
			aiReq.UserAnswer = *req.AnswerContext.UserAnswer
		}
		if req.AnswerContext.IsCorrect != nil {
			aiReq.IsCorrect = req.AnswerContext.IsCorrect
		}
	}

	// Include conversation history if provided.
	if req.ConversationHistory != nil {
		aiReq.ConversationHistory = make([]models.ChatMessage, len(*req.ConversationHistory))
		for i, msg := range *req.ConversationHistory {
			// Messages without text content are kept with empty content.
			contentText := ""
			if msg.Content.Text != nil {
				contentText = *msg.Content.Text
			}
			aiReq.ConversationHistory[i] = models.ChatMessage{
				Role:    msg.Role,
				Content: contentText,
			}
		}
	}

	// Build the per-user AI configuration; empty Provider/Model/APIKey are left
	// for the AI service to resolve.
	userConfig := &models.UserAIConfig{
		Provider: "",
		Model:    "",
		APIKey:   "",
		Username: user.Username,
	}
	if user.AIProvider.Valid && user.AIProvider.String != "" {
		userConfig.Provider = user.AIProvider.String
	}
	if user.AIModel.Valid && user.AIModel.String != "" {
		userConfig.Model = user.AIModel.String
	}

	// Look up the saved per-provider API key. A lookup failure is tolerated:
	// the request proceeds without a key.
	var apiKeyID *int
	if userConfig.Provider != "" {
		savedKey, keyID, err := h.userService.GetUserAPIKeyWithID(c.Request.Context(), userID, userConfig.Provider)
		if err == nil && savedKey != "" {
			userConfig.APIKey = savedKey
			apiKeyID = keyID
		}
	}

	// Set up Server-Sent Events headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Allow-Headers", "Cache-Control")

	// Channel through which the AI service delivers response chunks.
	chunks := make(chan string, 10)

	// The request context signals client disconnect; the timeout context bounds
	// the total streaming time.
	reqCtx := c.Request.Context()
	timeoutCtx, cancel := context.WithTimeout(reqCtx, config.QuizStreamTimeout)
	defer cancel()

	ctx, combinedCancel := context.WithCancel(timeoutCtx)
	defer combinedCancel()

	// Store userID and apiKeyID in the context handed to the AI service for
	// usage tracking.
	ctx = contextutils.WithUserID(ctx, userID)
	if apiKeyID != nil {
		ctx = contextutils.WithAPIKeyID(ctx, *apiKeyID)
	}

	// Watch for client disconnect.
	go func() {
		defer func() {
			if r := recover(); r != nil {
				h.logger.Error(ctx, "Panic in client disconnect watcher", nil, map[string]any{
					"panic": r,
				})
			}
		}()
		select {
		case <-reqCtx.Done():
			combinedCancel() // Cancel if client disconnects
		case <-ctx.Done():
			// Context already cancelled
		}
	}()

	// Run the AI streaming in its own goroutine; it owns and closes chunks.
	go func() {
		defer func() {
			if r := recover(); r != nil {
				h.logger.Error(ctx, "Panic in AI streaming goroutine", nil, map[string]interface{}{
					"panic": r,
				})
			}
			close(chunks) // Signals the stream loop that no more chunks follow
		}()
		if err := h.aiService.GenerateChatResponseStream(ctx, userConfig, aiReq, chunks); err != nil {
			h.logger.Error(ctx, "AI chat streaming failed for user", err, map[string]interface{}{
				"user_id": contextutils.GetUserIDFromContext(ctx),
			})
			// Only report the error to the client while the context is live,
			// and never block if the buffer is full.
			if ctx.Err() == nil {
				select {
				case chunks <- fmt.Sprintf("ERROR: %v", err):
				default:
					// Channel full, skip sending error
				}
			}
		}
	}()

	// Stream the response chunks until the channel closes, an error chunk
	// arrives, a write fails, or the context is done.
	c.Stream(func(w io.Writer) bool {
		select {
		case chunk, ok := <-chunks:
			if !ok {
				// Channel closed, end streaming
				return false
			}
			// Handle error messages
			if strings.HasPrefix(chunk, "ERROR: ") {
				c.SSEvent("error", chunk[7:]) // Remove "ERROR: " prefix
				return false
			}
			// JSON-encode the chunk so newlines and special characters survive
			// the SSE framing.
			jsonChunk, err := json.Marshal(chunk)
			if err != nil {
				h.logger.Error(ctx, "Failed to marshal chat stream chunk to JSON", err)
				return true // Continue streaming, skip this chunk
			}
			if _, err := fmt.Fprintf(w, "data: %s\n\n", jsonChunk); err != nil {
				h.logger.Error(ctx, "Failed to write chat stream data", err)
				return false
			}
			c.Writer.Flush()
			return true
		case <-ctx.Done():
			c.SSEvent("error", "Request timeout")
			return false
		}
	})
}
// Helper methods
// selectRandomQuestionType picks one of the four supported question types
// (vocabulary, fill-in-blank, question-answer, reading comprehension) using
// the package-level math/rand source. It makes no external calls, so it is
// not traced.
func (h *QuizHandler) selectRandomQuestionType() models.QuestionType {
	candidates := [...]models.QuestionType{
		models.Vocabulary,
		models.FillInBlank,
		models.QuestionAnswer,
		models.ReadingComprehension,
	}
	pick := rand.Intn(len(candidates))
	return candidates[pick]
}
// selectRandomQuestionTypeExcluding picks a random question type from the four
// supported types after removing every type listed in excludeTypes. When all
// types are excluded it falls back to models.Vocabulary.
func (h *QuizHandler) selectRandomQuestionTypeExcluding(excludeTypes ...models.QuestionType) models.QuestionType {
	allTypes := [...]models.QuestionType{
		models.Vocabulary,
		models.FillInBlank,
		models.QuestionAnswer,
		models.ReadingComprehension,
	}

	// Keep, in their original order, the types that are not excluded.
	remaining := make([]models.QuestionType, 0, len(allTypes))
	for _, candidate := range allTypes {
		excluded := false
		for _, banned := range excludeTypes {
			if banned == candidate {
				excluded = true
				break
			}
		}
		if !excluded {
			remaining = append(remaining, candidate)
		}
	}

	if len(remaining) == 0 {
		return models.Vocabulary // Default fallback
	}
	return remaining[rand.Intn(len(remaining))]
}
// GetWorkerStatus responds with a summary of worker health for the
// authenticated user: pause flags, worker counts, whether any worker is
// running, and the first recorded worker error (if any).
//
// A failure to fetch worker health fails the request. Failures of the pause
// checks are logged and treated as "not paused".
//
// Every worker instance is inspected, so worker_running reflects any running
// instance regardless of where an erroring instance appears in the list;
// last_error_details carries the first non-empty last_run_error found.
func (h *QuizHandler) GetWorkerStatus(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_worker_status")
	defer observability.FinishSpan(span, nil)

	userID, exists := h.getUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(observability.AttributeUserID(userID))

	workerHealth, err := h.workerService.GetWorkerHealth(ctx)
	if err != nil {
		h.logger.Error(ctx, "Failed to get worker health", err)
		HandleAppError(c, contextutils.WrapError(err, "failed to get worker status"))
		return
	}

	userPaused, err := h.workerService.IsUserPaused(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to check user pause status", err, nil)
		userPaused = false // Default to not paused if check fails
	}

	globalPaused, err := h.workerService.IsGlobalPaused(ctx)
	if err != nil {
		h.logger.Error(ctx, "Failed to check global pause status", err, nil)
		globalPaused = false // Default to not paused if check fails
	}

	// Scan all instances: remember the first non-empty string error and note
	// whether any instance reports is_running == true.
	lastErrorDetails := ""
	workerRunning := false
	if workerInstances, ok := workerHealth["worker_instances"].([]map[string]interface{}); ok {
		for _, instance := range workerInstances {
			if lastErrorDetails == "" {
				// Only string errors are surfaced; a missing or nil value
				// fails the assertion and is skipped.
				if errorStr, ok := instance["last_run_error"].(string); ok && errorStr != "" {
					lastErrorDetails = errorStr
				}
			}
			if isRunning, ok := instance["is_running"].(bool); ok && isRunning {
				workerRunning = true
			}
		}
	}

	response := gin.H{
		"has_errors":         false,
		"error_message":      "",
		"global_paused":      globalPaused,
		"user_paused":        userPaused,
		"healthy_workers":    workerHealth["healthy_count"],
		"total_workers":      workerHealth["total_count"],
		"last_error_details": "",
		"worker_running":     workerRunning,
	}
	if lastErrorDetails != "" {
		response["has_errors"] = true
		response["error_message"] = "Worker encountered errors during question generation"
		response["last_error_details"] = lastErrorDetails
	}

	c.JSON(http.StatusOK, response)
}
// Helper functions for enhanced progress information
// getWorkerStatusForUser builds the api.WorkerStatus shown alongside a user's
// progress.
//
// It returns an error only when worker health cannot be fetched. Failures of
// the pause checks are logged as warnings and treated as "not paused".
//
// The status is api.Idle unless generation is unpaused and a running worker
// has a heartbeat within the last 30 seconds, in which case it is api.Busy.
// The first non-empty last_run_error forces api.Idle, is returned as the error
// message, and stops the scan. LastHeartbeat comes from the first worker
// instance only.
func (h *QuizHandler) getWorkerStatusForUser(ctx context.Context, userID int) (*api.WorkerStatus, error) {
	workerHealth, err := h.workerService.GetWorkerHealth(ctx)
	if err != nil {
		return nil, err
	}

	userPaused, err := h.workerService.IsUserPaused(ctx, userID)
	if err != nil {
		h.logger.Warn(ctx, "Failed to check user pause status", map[string]interface{}{
			"user_id": userID,
			"error":   err.Error(),
		})
		userPaused = false // Default to not paused if check fails
	}

	globalPaused, err := h.workerService.IsGlobalPaused(ctx)
	if err != nil {
		h.logger.Warn(ctx, "Failed to check global pause status", map[string]interface{}{
			"user_id": userID,
			"error":   err.Error(),
		})
		globalPaused = false // Default to not paused if check fails
	}

	// A failed assertion leaves workerInstances nil, which the loop and the
	// length check below both handle.
	workerInstances, _ := workerHealth["worker_instances"].([]map[string]interface{})

	// Paused states are reported as idle, so only the unpaused case needs to
	// inspect the instances.
	var status api.WorkerStatusStatus = api.Idle
	var errorMessage *string
	if !globalPaused && !userPaused {
		for _, instance := range workerInstances {
			// Errors take precedence: report idle with the error message.
			if errorStr, ok := instance["last_run_error"].(string); ok && errorStr != "" {
				status = api.Idle
				errorMessage = &errorStr
				break
			}

			if isRunning, ok := instance["is_running"].(bool); !ok || !isRunning {
				continue
			}
			heartbeatStr, ok := instance["last_heartbeat"].(string)
			if !ok {
				continue
			}
			// A running worker counts as busy only with a very recent
			// heartbeat (within the last 30 seconds).
			if heartbeat, err := time.Parse(time.RFC3339, heartbeatStr); err == nil && time.Since(heartbeat) < 30*time.Second {
				status = api.Busy
			}
		}
	}

	var lastHeartbeat *time.Time
	if len(workerInstances) > 0 {
		if heartbeatStr, ok := workerInstances[0]["last_heartbeat"].(string); ok {
			if heartbeat, err := time.Parse(time.RFC3339, heartbeatStr); err == nil {
				lastHeartbeat = &heartbeat
			}
		}
	}

	return &api.WorkerStatus{
		Status:        &status,
		LastHeartbeat: formatTimePointer(lastHeartbeat),
		ErrorMessage:  errorMessage,
	}, nil
}
// getPriorityInsightsForUser summarises the user's priority score distribution
// as high/medium/low question counts plus their total.
//
// Only the "high", "medium" and "low" entries are read, and only when they
// hold an int; anything else counts as zero. An error from the learning
// service is returned unchanged with a nil result.
func (h *QuizHandler) getPriorityInsightsForUser(ctx context.Context, userID int) (*api.PriorityInsights, error) {
	distribution, err := h.learningService.GetUserPriorityScoreDistribution(ctx, userID)
	if err != nil {
		return nil, err
	}

	// asCount yields the value as an int, or 0 when it is missing or not an int.
	asCount := func(value interface{}) int {
		if n, ok := value.(int); ok {
			return n
		}
		return 0
	}

	high := asCount(distribution["high"])
	medium := asCount(distribution["medium"])
	low := asCount(distribution["low"])
	total := high + medium + low

	insights := &api.PriorityInsights{
		TotalQuestionsInQueue:   &total,
		HighPriorityQuestions:   &high,
		MediumPriorityQuestions: &medium,
		LowPriorityQuestions:    &low,
	}
	return insights, nil
}
// getGenerationFocusForUser reports which AI model is used for the user's
// question generation, together with a last-generation time and a generation
// rate.
//
// The model is the user's configured AI model, or "default" when none is set.
// The last-generation time (one hour ago) and the rate (2.5) are placeholders,
// not measured values.
//
// It returns the lookup error when the user cannot be loaded, and
// contextutils.ErrRecordNotFound when the lookup succeeds but yields no user.
func (h *QuizHandler) getGenerationFocusForUser(ctx context.Context, userID int) (*api.GenerationFocus, error) {
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		return nil, err
	}
	// Guard against a nil user with a nil error, which would otherwise panic
	// on the field accesses below.
	if user == nil {
		return nil, contextutils.ErrRecordNotFound
	}

	model := "default"
	if user.AIModel.Valid && user.AIModel.String != "" {
		model = user.AIModel.String
	}

	// Placeholder: could be replaced with data from actual generation logs.
	lastGenerationTime := time.Now().Add(-time.Hour)
	// Placeholder: average questions per minute.
	generationRate := float32(2.5)

	return &api.GenerationFocus{
		CurrentGenerationModel: &model,
		LastGenerationTime:     formatTimePtr(lastGenerationTime),
		GenerationRate:         &generationRate,
	}, nil
}
// getHighPriorityTopicsForUser fetches the user's high priority topics from
// the learning service. On failure it returns a nil slice and the service's
// error unchanged.
func (h *QuizHandler) getHighPriorityTopicsForUser(ctx context.Context, userID int) ([]string, error) {
	highPriority, fetchErr := h.learningService.GetHighPriorityTopics(ctx, userID)
	if fetchErr != nil {
		return nil, fetchErr
	}
	return highPriority, nil
}
// getGapAnalysisForUser fetches the user's gap analysis from the learning
// service. On failure it returns a nil map and the service's error unchanged.
func (h *QuizHandler) getGapAnalysisForUser(ctx context.Context, userID int) (map[string]interface{}, error) {
	analysis, fetchErr := h.learningService.GetGapAnalysis(ctx, userID)
	if fetchErr != nil {
		return nil, fetchErr
	}
	return analysis, nil
}
// getPriorityDistributionForUser fetches the user's priority distribution from
// the learning service. On failure it returns a nil map and the service's
// error unchanged.
func (h *QuizHandler) getPriorityDistributionForUser(ctx context.Context, userID int) (map[string]int, error) {
	byPriority, fetchErr := h.learningService.GetPriorityDistribution(ctx, userID)
	if fetchErr != nil {
		return nil, fetchErr
	}
	return byPriority, nil
}
// convertLearningPreferencesToAPI maps stored learning preferences onto their
// API representation.
//
// TtsVoice is set only when the stored voice is non-empty, and DailyGoal only
// when the stored goal is positive; otherwise those fields are left nil.
// prefs must be non-nil, because it is dereferenced without a check.
func convertLearningPreferencesToAPI(prefs *models.UserLearningPreferences) *api.UserLearningPreferences {
	converted := api.UserLearningPreferences{
		FocusOnWeakAreas:     prefs.FocusOnWeakAreas,
		FreshQuestionRatio:   float32(prefs.FreshQuestionRatio),
		KnownQuestionPenalty: float32(prefs.KnownQuestionPenalty),
		ReviewIntervalDays:   prefs.ReviewIntervalDays,
		WeakAreaBoost:        float32(prefs.WeakAreaBoost),
		DailyReminderEnabled: prefs.DailyReminderEnabled,
	}

	// Copy into locals so the API struct points at its own values rather than
	// into prefs.
	if voice := prefs.TTSVoice; voice != "" {
		converted.TtsVoice = &voice
	}
	if goal := prefs.DailyGoal; goal > 0 {
		converted.DailyGoal = &goal
	}

	return &converted
}
package handlers
import (
"fmt"
"net/http"
"sort"
"strings"
"time"
"quizapp/internal/observability"
"github.com/gin-gonic/gin"
)
// RouteInfo describes a single registered route. It is the element type of the
// JSON route listing, and one table row is rendered per RouteInfo in the HTML
// listing.
type RouteInfo struct {
// Method is the HTTP method of the route.
Method string `json:"method"`
// Path is the route's path as registered with the router.
Path string `json:"path"`
// HandlerName is the name of the handler reported by the router.
HandlerName string `json:"handler_name"`
}
// RouteListingHandler serves a listing of a service's routes as an HTML page
// or as JSON. The listing is whatever the last CollectRoutes call stored.
type RouteListingHandler struct {
// serviceName is shown in the title and headings of the HTML page.
serviceName string
// routes holds the collected routes, sorted by CollectRoutes.
routes []RouteInfo
}
// NewRouteListingHandler returns a handler for the named service with an
// empty, non-nil route list. Call CollectRoutes to populate it.
func NewRouteListingHandler(serviceName string) *RouteListingHandler {
	handler := new(RouteListingHandler)
	handler.serviceName = serviceName
	// Non-nil so an uncollected listing still encodes as a JSON array.
	handler.routes = make([]RouteInfo, 0)
	return handler
}
// CollectRoutes replaces the handler's route list with the routes currently
// registered on engine, skipping any route whose path starts with "/debug/".
//
// Routes are ordered by path, and by HTTP method among routes sharing a path,
// so the listing order is fully determined by the registered routes. The
// stored list is never nil, even when no routes are collected.
func (h *RouteListingHandler) CollectRoutes(engine *gin.Engine) {
	registered := engine.Routes()

	collected := make([]RouteInfo, 0, len(registered))
	for _, route := range registered {
		if strings.HasPrefix(route.Path, "/debug/") {
			continue
		}
		collected = append(collected, RouteInfo{
			Method:      route.Method,
			Path:        route.Path,
			HandlerName: route.Handler,
		})
	}

	// sort.Slice is not stable, so break ties on Method to keep routes that
	// share a path in a well-defined order.
	sort.Slice(collected, func(i, j int) bool {
		if collected[i].Path != collected[j].Path {
			return collected[i].Path < collected[j].Path
		}
		return collected[i].Method < collected[j].Method
	})

	h.routes = collected
}
// GetRouteListingPage responds with the collected routes rendered as an HTML
// page. The response carries headers that tell clients not to cache it.
func (h *RouteListingHandler) GetRouteListingPage(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_route_listing_page")
	defer observability.FinishSpan(span, nil)

	page := h.generateHTML()

	// Content type first, then the no-cache headers.
	responseHeaders := [...][2]string{
		{"Content-Type", "text/html; charset=utf-8"},
		{"Cache-Control", "no-cache, no-store, must-revalidate"},
		{"Pragma", "no-cache"},
		{"Expires", "0"},
	}
	for _, header := range responseHeaders {
		c.Header(header[0], header[1])
	}

	c.String(http.StatusOK, page)
}
// GetRouteListingJSON responds with the collected routes encoded as a JSON
// array of RouteInfo objects.
func (h *RouteListingHandler) GetRouteListingJSON(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_route_listing_json")
	defer observability.FinishSpan(span, nil)

	listing := h.routes
	c.JSON(http.StatusOK, listing)
}
// generateHTML renders the collected routes as a complete HTML document: a
// summary (service name, generation time, total/GET/POST counts) followed by
// one table row per route and a small script that navigates to GET routes on
// click.
//
// The service name, route paths and handler names are interpolated into the
// markup without escaping.
func (h *RouteListingHandler) generateHTML() string {
	routeTotal := fmt.Sprintf("%d", len(h.routes))
	getTotal := fmt.Sprintf("%d", h.countMethods("GET"))
	postTotal := fmt.Sprintf("%d", h.countMethods("POST"))
	generatedAt := time.Now().Format("2006-01-02 15:04:05")

	// The header is assembled by concatenation, not a format string, because
	// the embedded CSS contains literal '%' characters.
	var page strings.Builder
	page.WriteString(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>` + h.serviceName + ` - Available Routes</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; line-height: 1.6; padding: 20px; background-color: #f8f9fa; color: #212529; }
.container { max-width: 1200px; margin: auto; background: #fff; padding: 30px; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.05); }
h1 { color: #0056b3; border-bottom: 2px solid #dee2e6; padding-bottom: 10px; margin-bottom: 30px; }
.service-info { background: #e7f3ff; padding: 15px; border-radius: 5px; margin-bottom: 30px; }
.route-table { width: 100%; border-collapse: collapse; margin-bottom: 30px; }
.route-table th, .route-table td { padding: 12px; text-align: left; border-bottom: 1px solid #dee2e6; }
.route-table th { background-color: #f8f9fa; font-weight: 600; color: #495057; }
.route-table tr:nth-child(even) { background-color: #f8f9fa; }
.route-table tr:hover { background-color: #e9ecef; }
.method { display: inline-block; padding: 4px 8px; border-radius: 4px; font-size: 12px; font-weight: bold; min-width: 60px; text-align: center; }
.method-get { background-color: #d4edda; color: #155724; }
.method-post { background-color: #cce5ff; color: #004085; }
.method-put { background-color: #fff3cd; color: #856404; }
.method-delete { background-color: #f8d7da; color: #721c24; }
.method-patch { background-color: #e2e3e5; color: #383d41; }
.path { font-family: "Monaco", "Menlo", "Ubuntu Mono", monospace; font-size: 14px; color: #6f42c1; }
.clickable-path { cursor: pointer; text-decoration: underline; }
.clickable-path:hover { background-color: #f8f9fa; }
.footer { margin-top: 30px; text-align: center; color: #6c757d; font-size: 14px; }
.stats { display: flex; gap: 20px; margin-bottom: 20px; }
.stat-box { background: #ffffff; border: 1px solid #dee2e6; padding: 15px; border-radius: 5px; text-align: center; flex: 1; }
.stat-number { font-size: 24px; font-weight: bold; color: #0056b3; }
.stat-label { color: #6c757d; font-size: 14px; }
</style>
</head>
<body>
<div class="container">
<h1>` + h.serviceName + ` Service - Available Routes</h1>
<div class="service-info">
<strong>Service:</strong> ` + h.serviceName + `<br>
<strong>Generated:</strong> ` + generatedAt + `<br>
<strong>Total Routes:</strong> ` + routeTotal + `
</div>
<div class="stats">
<div class="stat-box">
<div class="stat-number">` + routeTotal + `</div>
<div class="stat-label">Total Routes</div>
</div>
<div class="stat-box">
<div class="stat-number">` + getTotal + `</div>
<div class="stat-label">GET Routes</div>
</div>
<div class="stat-box">
<div class="stat-number">` + postTotal + `</div>
<div class="stat-label">POST Routes</div>
</div>
</div>
<table class="route-table">
<thead>
<tr>
<th>Method</th>
<th>Path</th>
<th>Handler</th>
</tr>
</thead>
<tbody>`)

	// Row template: method class, method, path class, method, path (for the
	// onclick handler), path (as display text), handler name.
	const rowTemplate = `
<tr>
<td><span class="method %s">%s</span></td>
<td><span class="%s" onclick="navigateToRoute('%s', '%s')">%s</span></td>
<td>%s</td>
</tr>`

	for _, route := range h.routes {
		// Only GET routes are styled as clickable.
		pathClass := "path"
		if route.Method == "GET" {
			pathClass = "path clickable-path"
		}
		fmt.Fprintf(&page, rowTemplate,
			"method-"+strings.ToLower(route.Method), route.Method,
			pathClass, route.Method, route.Path, route.Path,
			route.HandlerName,
		)
	}

	page.WriteString(`
</tbody>
</table>
<div class="footer">
<p>Click on any GET route path to navigate to it | <a href="/?json=true">View as JSON</a></p>
</div>
</div>
<script>
function navigateToRoute(method, path) {
if (method === 'GET') {
window.location.href = path;
} else {
alert('Only GET routes can be navigated to directly. Use API client for ' + method + ' requests.');
}
}
</script>
</body>
</html>`)

	return page.String()
}
// countMethods returns how many collected routes use exactly the given HTTP
// method. The comparison is case-sensitive.
func (h *RouteListingHandler) countMethods(method string) int {
	var matches int
	for i := range h.routes {
		if h.routes[i].Method != method {
			continue
		}
		matches++
	}
	return matches
}
package handlers
import (
"encoding/json"
"net/http"
"os"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-contrib/secure"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"quizapp/internal/config"
"quizapp/internal/middleware"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/version"
)
// IMPORTANT: When adding new API endpoints, make sure to:
// 1. Add them to swagger.yaml with proper documentation
// 2. Run `task generate-api-types` to regenerate types
// 3. Update any relevant tests
// 4. Consider if the endpoint should be public or admin-only
// NewRouter builds the backend's gin engine: it installs the middleware
// chain, constructs the HTTP handlers from the supplied services, registers
// every route and returns the ready-to-serve engine.
//
// Middleware is attached in this order: gin.Recovery, structured request
// logging, OpenTelemetry tracing, response validation, CORS, cookie sessions
// and the security-header middleware. Routes registered between two Use calls
// only receive the middleware attached before them.
//
// cfg supplies server settings (CORS origins, session secret, debug flag,
// worker URL, Linear integration); logger is used by the request logging
// middleware and is handed to the handlers. The remaining parameters are the
// services injected into the handlers.
func NewRouter(
	cfg *config.Config,
	userService services.UserServiceInterface,
	questionService services.QuestionServiceInterface,
	learningService services.LearningServiceInterface,
	aiService services.AIServiceInterface,
	workerService services.WorkerServiceInterface,
	dailyQuestionService services.DailyQuestionServiceInterface,
	storyService services.StoryServiceInterface,
	conversationService services.ConversationServiceInterface,
	oauthService *services.OAuthService,
	generationHintService services.GenerationHintServiceInterface,
	translationService services.TranslationServiceInterface,
	snippetsService services.SnippetsServiceInterface,
	usageStatsService services.UsageStatsServiceInterface,
	wordOfTheDayService services.WordOfTheDayServiceInterface,
	authAPIKeyService services.AuthAPIKeyServiceInterface,
	logger *observability.Logger,
) *gin.Engine {
	// Setup Gin router
	router := gin.New()
	router.Use(gin.Recovery())

	// Add HTTP request logging middleware using our observability logger
	router.Use(func(c *gin.Context) {
		start := time.Now()

		// Process request
		c.Next()

		// Log request details using our observability logger
		latency := time.Since(start)
		statusCode := c.Writer.Status()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		// Create structured log entry
		fields := map[string]interface{}{
			"http.method":      method,
			"http.path":        path,
			"http.status_code": statusCode,
			"http.latency_ms":  latency.Milliseconds(),
			"http.client_ip":   clientIP,
			"http.user_agent":  c.Request.UserAgent(),
		}

		// Add error message if present
		if len(c.Errors) > 0 {
			fields["http.error"] = c.Errors.String()
		}

		// For failed requests (4xx and 5xx), add extra debugging context
		if statusCode >= 400 {
			// Only the size of the response is recorded; the body itself has
			// already been written and is not captured here.
			if c.Writer.Size() > 0 {
				fields["http.response_size"] = c.Writer.Size()
			}

			// Classify the failure and add more context for 5xx errors
			if statusCode >= 500 {
				fields["http.error_type"] = "server_error"
				// Log additional context that might help debugging
				if c.Request.Body != nil {
					fields["http.request_has_body"] = true
				}
			} else {
				fields["http.error_type"] = "client_error"
			}
		}

		// Log using our observability logger, choosing the level from the
		// status code: 5xx -> Error, 4xx -> Warn, everything else -> Info.
		if statusCode >= 500 {
			logger.Error(c.Request.Context(), "HTTP request failed", nil, fields)
		} else if statusCode >= 400 {
			logger.Warn(c.Request.Context(), "HTTP request warning", fields)
		} else {
			logger.Info(c.Request.Context(), "HTTP request", fields)
		}
	})

	// Health check endpoint. It is registered after Recovery and request
	// logging but before the tracing, validation, CORS, session and security
	// middleware below.
	router.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok", "service": "backend"})
	})

	// Add OpenTelemetry middleware for HTTP tracing and context propagation with automatic error attributes
	router.Use(observability.GinMiddlewareWithErrorHandling("quiz-backend"))

	// Add response validation middleware for API endpoints
	router.Use(middleware.ResponseValidationMiddleware(logger))

	// Swagger documentation (registered before the CORS, session and security middleware)
	router.StaticFile("/swagger.yaml", "./swagger.yaml")
	router.StaticFile("/swaggerz", "./swaggerz.html")

	// Disable automatic redirection for trailing slashes, which is better for APIs
	router.RedirectTrailingSlash = false

	// Setup CORS middleware
	corsConfig := cors.DefaultConfig()
	corsConfig.AllowOrigins = cfg.Server.CORSOrigins
	corsConfig.AllowCredentials = true
	corsConfig.AllowHeaders = []string{"Origin", "Content-Length", "Content-Type", "Authorization", "X-Requested-With"}
	corsConfig.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
	router.Use(cors.New(corsConfig))

	// Setup session middleware
	store := cookie.NewStore([]byte(cfg.Server.SessionSecret))

	// Configure session options for security
	sessionOpts := sessions.Options{
		Path:     config.SessionPath,
		MaxAge:   int(config.SessionMaxAge.Seconds()),
		HttpOnly: config.SessionHTTPOnly,
		Secure:   config.SessionSecure, // Set to true in production with HTTPS
	}
	// Outside debug mode the cookie is always Secure with SameSite=Lax; in
	// debug mode Secure keeps the configured default.
	if cfg.Server.Debug {
		sessionOpts.SameSite = http.SameSiteDefaultMode
	} else {
		sessionOpts.SameSite = http.SameSiteLaxMode
		sessionOpts.Secure = true
	}
	store.Options(sessionOpts)
	router.Use(sessions.Sessions(config.SessionName, store))

	// Setup Gin mode: release by default, debug only when configured
	gin.SetMode(gin.ReleaseMode)
	if cfg.Server.Debug {
		gin.SetMode(gin.DebugMode)
	}

	// Security middleware
	secureConfig := secure.DefaultConfig()
	secureConfig.SSLRedirect = false
	secureConfig.ContentSecurityPolicy = config.DefaultCSP
	router.Use(secure.New(secureConfig))

	// Note: Static assets are now served from the frontend build (see the
	// /assets route near the end of this function).

	// Initialize handlers
	authHandler := NewAuthHandler(userService, oauthService, cfg, logger)
	authAPIKeyHandler := NewAuthAPIKeyHandler(authAPIKeyService, logger)
	emailService := services.CreateEmailService(cfg, logger)
	settingsHandler := NewSettingsHandler(userService, storyService, conversationService, aiService, learningService, emailService, usageStatsService, cfg, logger)
	quizHandler := NewQuizHandler(userService, questionService, aiService, learningService, workerService, generationHintService, usageStatsService, cfg, logger)
	dailyQuestionHandler := NewDailyQuestionHandler(userService, dailyQuestionService, cfg, logger)
	storyHandler := NewStoryHandler(storyService, userService, aiService, cfg, logger)
	aiConversationHandler := NewAIConversationHandler(conversationService, cfg, logger)
	translationHandler := NewTranslationHandler(translationService, cfg, logger)
	snippetsHandler := NewSnippetsHandler(snippetsService, cfg, logger)
	wordOfTheDayHandler := NewWordOfTheDayHandler(userService, wordOfTheDayService, cfg, logger)
	adminHandler := NewAdminHandlerWithLogger(userService, questionService, aiService, cfg, learningService, workerService, logger, usageStatsService)
	// Inject story service into admin handler via its package-level field
	adminHandler.storyService = storyService
	userAdminHandler := NewUserAdminHandler(userService, cfg, logger)
	verbConjugationHandler := NewVerbConjugationHandler(logger)
	feedbackService := services.NewFeedbackService(userService.GetDB(), logger)

	// Initialize Linear service if enabled; otherwise the handler receives a nil service
	var linearService *services.LinearService
	if cfg.Linear.Enabled {
		linearService = services.NewLinearService(cfg, logger)
	}
	feedbackHandler := NewFeedbackHandler(feedbackService, linearService, userService, cfg, logger)

	// V1 routes (matching swagger spec)
	v1 := router.Group("/v1")
	{
		// Version aggregation endpoint (no auth)
		v1.GET("/version", func(c *gin.Context) {
			backendVersion := gin.H{
				"service":   "backend",
				"version":   version.Version,
				"commit":    version.Commit,
				"buildTime": version.BuildTime,
			}

			// The environment variable takes precedence over the configured URL
			workerInternalURL := os.Getenv("WORKER_INTERNAL_URL")
			if workerInternalURL == "" {
				workerInternalURL = cfg.Server.WorkerInternalURL // fallback
			}

			// Use instrumented HTTP client for tracing
			client := &http.Client{
				Transport: otelhttp.NewTransport(http.DefaultTransport),
			}

			// Bind the outbound request to the inbound request's context so
			// it is cancelled together with the caller.
			req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, workerInternalURL+"/v1/version", nil)
			var workerResp *http.Response
			if err == nil {
				workerResp, err = client.Do(req)
			}
			// Close the body after every successful round trip, not only for
			// 200 responses; otherwise a non-200 worker reply leaks the body
			// and its underlying connection.
			if err == nil {
				defer func() { _ = workerResp.Body.Close() }()
			}

			var workerVersion interface{}
			if err == nil && workerResp.StatusCode == http.StatusOK {
				if err := json.NewDecoder(workerResp.Body).Decode(&workerVersion); err != nil {
					workerVersion = gin.H{"error": "Failed to decode worker version"}
				}
			} else {
				workerVersion = gin.H{"error": "Worker unavailable"}
			}

			// The endpoint always answers 200; worker problems are reported
			// inside the "worker" field.
			c.JSON(http.StatusOK, gin.H{
				"backend": backendVersion,
				"worker":  workerVersion,
			})
		})

		auth := v1.Group("/auth")
		{
			auth.POST("/login", middleware.RequestValidationMiddleware(logger), authHandler.Login)
			auth.POST("/logout", authHandler.Logout)
			auth.GET("/status", authHandler.Status)
			auth.GET("/check", middleware.RequireAuth(), authHandler.Check)
			auth.POST("/signup", middleware.RequestValidationMiddleware(logger), authHandler.Signup)
			auth.GET("/signup/status", authHandler.SignupStatus)
			auth.GET("/google/login", authHandler.GoogleLogin)
			auth.GET("/google/callback", authHandler.GoogleCallback)
		}

		// API Keys routes (for programmatic API access)
		apiKeys := v1.Group("/api-keys")
		apiKeys.Use(middleware.RequireAuth()) // Keep session-only auth for managing API keys
		{
			apiKeys.POST("", middleware.RequestValidationMiddleware(logger), authAPIKeyHandler.CreateAPIKey)
			apiKeys.GET("", authAPIKeyHandler.ListAPIKeys)
			apiKeys.DELETE("/:id", authAPIKeyHandler.DeleteAPIKey)
		}

		// API Key test endpoints using API key auth (no cookies)
		apiKeysTest := v1.Group("/api-keys")
		apiKeysTest.Use(middleware.RequireAuthWithAPIKey(authAPIKeyService, userService))
		{
			apiKeysTest.GET("/test-read", authAPIKeyHandler.TestRead)
			apiKeysTest.POST("/test-write", authAPIKeyHandler.TestWrite)
		}

		// Translation routes
		v1.POST("/translate", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), translationHandler.TranslateText)

		// Feedback routes
		v1.POST("/feedback", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), feedbackHandler.SubmitFeedback)

		// Snippets routes
		v1.POST("/snippets", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), snippetsHandler.CreateSnippet)
		v1.GET("/snippets", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), snippetsHandler.GetSnippets)
		v1.DELETE("/snippets", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), snippetsHandler.DeleteAllSnippets)
		v1.GET("/snippets/search", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), snippetsHandler.SearchSnippets)
		v1.GET("/snippets/by-question/:question_id", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), snippetsHandler.GetSnippetsByQuestion)
		v1.GET("/snippets/by-section/:section_id", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), snippetsHandler.GetSnippetsBySection)
		v1.GET("/snippets/by-story/:story_id", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), snippetsHandler.GetSnippetsByStory)
		v1.GET("/snippets/:id", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), snippetsHandler.GetSnippet)
		v1.PUT("/snippets/:id", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), snippetsHandler.UpdateSnippet)
		v1.DELETE("/snippets/:id", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), snippetsHandler.DeleteSnippet)

		quiz := v1.Group("/quiz")
		quiz.Use(middleware.RequireAuthWithAPIKey(authAPIKeyService, userService))
		quiz.Use(middleware.RequestValidationMiddleware(logger))
		{
			quiz.GET("/question", quizHandler.GetQuestion)
			quiz.GET("/question/:id", quizHandler.GetQuestion)
			quiz.POST("/question/:id/report", quizHandler.ReportQuestion)
			quiz.POST("/question/:id/mark-known", quizHandler.MarkQuestionAsKnown)
			quiz.POST("/answer", quizHandler.SubmitAnswer)
			quiz.GET("/progress", quizHandler.GetProgress)
			quiz.GET("/ai-token-usage", quizHandler.GetAITokenUsage)
			quiz.GET("/ai-token-usage/daily", quizHandler.GetAITokenUsageDaily)
			quiz.GET("/ai-token-usage/hourly", quizHandler.GetAITokenUsageHourly)
			quiz.GET("/worker-status", quizHandler.GetWorkerStatus)
			quiz.POST("/chat/stream", quizHandler.ChatStream)
		}

		daily := v1.Group("/daily")
		daily.Use(middleware.RequireAuthWithAPIKey(authAPIKeyService, userService))
		daily.Use(middleware.RequestValidationMiddleware(logger))
		{
			daily.GET("/questions/:date", dailyQuestionHandler.GetDailyQuestions)
			daily.POST("/questions/:date/complete/:questionId", dailyQuestionHandler.MarkQuestionCompleted)
			daily.DELETE("/questions/:date/complete/:questionId", dailyQuestionHandler.ResetQuestionCompleted)
			daily.POST("/questions/:date/answer/:questionId", dailyQuestionHandler.SubmitDailyQuestionAnswer)
			daily.GET("/history/:questionId", dailyQuestionHandler.GetQuestionHistory)
			daily.GET("/dates", dailyQuestionHandler.GetAvailableDates)
			daily.GET("/progress/:date", dailyQuestionHandler.GetDailyProgress)
			// Note: Assignment is handled automatically by the worker
		}

		wordOfDay := v1.Group("/word-of-day")
		{
			// Protected endpoints requiring authentication (API key or session)
			wordOfDay.GET("", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), wordOfTheDayHandler.GetWordOfTheDayToday)
			wordOfDay.GET("/history", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), wordOfTheDayHandler.GetWordOfTheDayHistory)
			// Embed endpoint supports optional date query parameter and requires API key or session auth
			wordOfDay.GET("/embed", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), wordOfTheDayHandler.GetWordOfTheDayEmbed)
			wordOfDay.GET("/:date/embed", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), wordOfTheDayHandler.GetWordOfTheDayEmbed)
			wordOfDay.GET("/:date", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), wordOfTheDayHandler.GetWordOfTheDay)
		}

		story := v1.Group("/story")
		story.Use(middleware.RequireAuthWithAPIKey(authAPIKeyService, userService))
		story.Use(middleware.RequestValidationMiddleware(logger))
		{
			story.POST("", storyHandler.CreateStory)
			story.GET("", storyHandler.GetUserStories)
			story.GET("/current", storyHandler.GetCurrentStory)
			story.GET("/:id", storyHandler.GetStory)
			story.GET("/section/:id", storyHandler.GetSection)
			story.POST("/:id/generate", storyHandler.GenerateNextSection)
			story.POST("/:id/archive", storyHandler.ArchiveStory)
			story.POST("/:id/complete", storyHandler.CompleteStory)
			story.POST("/:id/set-current", storyHandler.SetCurrentStory)
			story.POST("/:id/toggle-auto-generation", storyHandler.ToggleAutoGeneration)
			story.DELETE("/:id", storyHandler.DeleteStory)
			story.GET("/:id/export", storyHandler.ExportStory)
		}

		settings := v1.Group("/settings")
		{
			// /levels and /languages are registered without auth middleware
			settings.GET("/ai-providers", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), settingsHandler.GetProviders)
			settings.GET("/levels", settingsHandler.GetLevels)
			settings.GET("/languages", settingsHandler.GetLanguages)
			settings.POST("/test-ai", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), settingsHandler.TestAIConnection)
			settings.POST("/test-email", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), settingsHandler.SendTestEmail)
			settings.PUT("", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), settingsHandler.UpdateUserSettings)
			settings.PUT("/word-of-day-email", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), settingsHandler.UpdateWordOfDayEmailPreference)
			// User data management endpoints
			settings.POST("/clear-stories", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), settingsHandler.ClearAllStories)
			settings.POST("/clear-ai-chats", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), settingsHandler.ClearAllAIChats)
			settings.POST("/reset-account", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), settingsHandler.ResetAccount)
			settings.GET("/api-key/:provider", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), settingsHandler.CheckAPIKeyAvailability)
		}

		// Verb conjugation endpoints
		verbConjugations := v1.Group("/verb-conjugations")
		verbConjugations.Use(middleware.RequireAuthWithAPIKey(authAPIKeyService, userService))
		{
			verbConjugations.GET("/info", verbConjugationHandler.GetVerbConjugationInfo)
			verbConjugations.GET("/languages", verbConjugationHandler.GetAvailableLanguages)
			verbConjugations.GET("/:language", verbConjugationHandler.GetVerbConjugations)
			verbConjugations.GET("/:language/:verb", verbConjugationHandler.GetVerbConjugation)
		}

		// AI conversation endpoints
		ai := v1.Group("/ai")
		ai.Use(middleware.RequireAuthWithAPIKey(authAPIKeyService, userService))
		ai.Use(middleware.RequestValidationMiddleware(logger))
		{
			ai.GET("/conversations", aiConversationHandler.GetConversations)
			ai.POST("/conversations", aiConversationHandler.CreateConversation)
			ai.GET("/conversations/:id", aiConversationHandler.GetConversation)
			ai.PUT("/conversations/:id", aiConversationHandler.UpdateConversation)
			ai.DELETE("/conversations/:id", aiConversationHandler.DeleteConversation)
			ai.POST("/conversations/:conversationId/messages", aiConversationHandler.AddMessage)
			ai.PUT("/conversations/bookmark", aiConversationHandler.ToggleMessageBookmark)
			ai.GET("/search", aiConversationHandler.SearchConversations)
			ai.GET("/bookmarks", aiConversationHandler.GetBookmarkedMessages)
		}

		preferences := v1.Group("/preferences")
		preferences.Use(middleware.RequireAuthWithAPIKey(authAPIKeyService, userService))
		preferences.Use(middleware.RequestValidationMiddleware(logger))
		{
			preferences.GET("/learning", settingsHandler.GetLearningPreferences)
			preferences.PUT("/learning", settingsHandler.UpdateLearningPreferences)
		}

		// User management endpoints (non-admin)
		userz := v1.Group("/userz")
		{
			userz.PUT("/profile", middleware.RequireAuthWithAPIKey(authAPIKeyService, userService), middleware.RequestValidationMiddleware(logger), userAdminHandler.UpdateCurrentUserProfile)
		}

		// Admin endpoints
		admin := v1.Group("/admin")
		admin.Use(middleware.RequireAdmin(userService))
		admin.Use(middleware.RequestValidationMiddleware(logger))
		{
			// Backend admin endpoints
			backend := admin.Group("/backend")
			{
				// Backend admin page
				backend.GET("", adminHandler.GetBackendAdminPage)

				// Feedback management (admin only)
				backend.GET("/feedback", feedbackHandler.ListFeedback)
				backend.GET("/feedback/:id", feedbackHandler.GetFeedback)
				backend.PATCH("/feedback/:id", feedbackHandler.UpdateFeedback)
				backend.DELETE("/feedback/:id", feedbackHandler.DeleteFeedback)
				backend.DELETE("/feedback", func(c *gin.Context) {
					// ?all=true deletes everything; otherwise deletion is by status
					if c.Query("all") == "true" {
						feedbackHandler.DeleteAllFeedback(c)
					} else {
						feedbackHandler.DeleteFeedbackByStatus(c)
					}
				})
				backend.POST("/feedback/:id/linear-issue", feedbackHandler.CreateLinearIssue)

				// User management (admin only)
				backend.GET("/userz", userAdminHandler.GetAllUsers)
				backend.GET("/userz/paginated", userAdminHandler.GetUsersPaginated)
				backend.POST("/userz", userAdminHandler.CreateUser)
				backend.PUT("/userz/:id", userAdminHandler.UpdateUser)
				backend.DELETE("/userz/:id", userAdminHandler.DeleteUser)
				backend.POST("/userz/:id/reset-password", userAdminHandler.ResetUserPassword)

				// Role management endpoints
				backend.GET("/roles", adminHandler.GetRoles)
				backend.GET("/userz/:id/roles", adminHandler.GetUserRoles)
				backend.POST("/userz/:id/roles", adminHandler.AssignRole)
				backend.DELETE("/userz/:id/roles/:roleId", adminHandler.RemoveRole)

				// Admin dashboard data
				backend.GET("/dashboard", adminHandler.GetBackendAdminData)
				backend.GET("/ai-concurrency", adminHandler.GetAIConcurrencyStats)

				// Question management
				backend.GET("/questions/:id", adminHandler.GetQuestion)
				backend.GET("/questions/:id/users", adminHandler.GetUsersForQuestion)
				backend.PUT("/questions/:id", adminHandler.UpdateQuestion)
				backend.DELETE("/questions/:id", adminHandler.DeleteQuestion)
				backend.POST("/questions/:id/assign-users", adminHandler.AssignUsersToQuestion)
				backend.POST("/questions/:id/unassign-users", adminHandler.UnassignUsersFromQuestion)
				backend.GET("/questions/paginated", adminHandler.GetQuestionsPaginated)
				backend.GET("/questions", adminHandler.GetAllQuestions)
				backend.GET("/reported-questions", adminHandler.GetReportedQuestionsPaginated)
				backend.POST("/questions/:id/fix", adminHandler.MarkQuestionAsFixed)
				backend.POST("/questions/:id/ai-fix", adminHandler.FixQuestionWithAI)

				// Data management
				backend.POST("/clear-user-data", adminHandler.ClearUserData)
				backend.POST("/clear-database", adminHandler.ClearDatabase)
				backend.POST("/userz/:id/clear", adminHandler.ClearUserDataForUser)

				// Story explorer (admin)
				backend.GET("/stories", adminHandler.GetStoriesPaginated)
				backend.GET("/stories/:id", adminHandler.GetStoryAdmin)
				backend.DELETE("/stories/:id", adminHandler.DeleteStoryAdmin)
				backend.GET("/story-sections/:id", adminHandler.GetSectionAdmin)

				// Usage stats (admin)
				backend.GET("/usage-stats", adminHandler.GetUsageStats)
				backend.GET("/usage-stats/:service", adminHandler.GetUsageStatsByService)
			}
		}
	}

	// Config dump endpoint
	// NOTE(review): registered on the root router without any auth middleware
	// here; confirm the handler itself restricts access.
	router.GET("/configz", adminHandler.GetConfigz)

	// Serve frontend static files
	router.Static("/assets", "./frontend/dist/assets")
	router.StaticFile("/favicon.svg", "./frontend/dist/favicon.svg")
	// NOTE(review): StaticFile registers a single-file route, but the target
	// looks like a directory; confirm whether router.Static was intended.
	router.StaticFile("/fonts", "./frontend/dist/fonts")

	// Catch-all route for SPA - serve index.html for any route that doesn't match API routes
	router.NoRoute(func(c *gin.Context) {
		// Don't serve index.html for API routes
		if strings.HasPrefix(c.Request.URL.Path, "/v1/") ||
			strings.HasPrefix(c.Request.URL.Path, "/configz") ||
			strings.HasPrefix(c.Request.URL.Path, "/swagger") ||
			strings.HasPrefix(c.Request.URL.Path, "/backend/") {
			c.JSON(http.StatusNotFound, gin.H{"error": "Not found"})
			return
		}
		// Serve the frontend's index.html for all other routes
		c.File("./frontend/dist/index.html")
	})

	// Automatic route listing at root path. Routes are collected before the
	// root route itself is registered.
	routeListing := NewRouteListingHandler("Backend")
	routeListing.CollectRoutes(router)

	// Root path shows all available routes (HTML by default, JSON with ?json=true)
	router.GET("/", func(c *gin.Context) {
		if c.Query("json") == "true" {
			routeListing.GetRouteListingJSON(c)
		} else {
			routeListing.GetRouteListingPage(c)
		}
	})

	return router
}
package handlers
import (
"quizapp/internal/middleware"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// GetUserIDFromSession resolves the ID of the current user.
//
// A value stored in the gin context under middleware.UserIDKey takes
// precedence: an int is returned as-is, a uint is converted to int, and any
// other type is treated as invalid. The cookie session is consulted only when
// the context holds no such key, and there the value must be an int.
//
// It returns (0, false) when no usable ID is found.
func GetUserIDFromSession(c *gin.Context) (int, bool) {
	if raw, found := c.Get(middleware.UserIDKey); found {
		switch v := raw.(type) {
		case int:
			return v, true
		case uint:
			// uint values show up in tests
			return int(v), true
		default:
			// Present but of an unexpected type: do not fall back to the session
			return 0, false
		}
	}

	// Nothing in the context; a missing or non-int session value fails the assertion
	stored, isInt := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !isInt {
		return 0, false
	}
	return stored, true
}
// GetUsernameFromSession resolves the username of the current user.
//
// A value stored in the gin context under middleware.UsernameKey takes
// precedence and must be a string; any other type is treated as invalid. The
// cookie session is consulted only when the context holds no such key.
//
// It returns ("", false) when no usable username is found.
func GetUsernameFromSession(c *gin.Context) (string, bool) {
	if raw, found := c.Get(middleware.UsernameKey); found {
		name, isString := raw.(string)
		if !isString {
			// Present but of an unexpected type: do not fall back to the session
			return "", false
		}
		return name, true
	}

	// Nothing in the context; a missing or non-string session value fails the assertion
	name, isString := sessions.Default(c).Get(middleware.UsernameKey).(string)
	if !isString {
		return "", false
	}
	return name, true
}
package handlers
import (
"fmt"
"net/http"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/middleware"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/services/mailer"
contextutils "quizapp/internal/utils"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// SettingsHandler handles user settings related HTTP requests.
// All dependencies are injected through NewSettingsHandler.
type SettingsHandler struct {
// userService persists user settings and the word-of-day email preference,
// and looks up saved per-provider API keys.
userService services.UserServiceInterface
// storyService: presumably backs the clear-stories endpoint; its use is
// outside this part of the file — TODO confirm.
storyService services.StoryServiceInterface
// conversationService: presumably backs the clear-ai-chats endpoint; its use
// is outside this part of the file — TODO confirm.
conversationService services.ConversationServiceInterface
// aiService is used to test a provider/model/API key combination.
aiService services.AIServiceInterface
// learningService loads the user's learning preferences.
learningService services.LearningServiceInterface
// usageStatsSvc: usage statistics service; its use is outside this part of
// the file — TODO confirm.
usageStatsSvc services.UsageStatsServiceInterface
// emailService: mailer, presumably used by the test-email endpoint — TODO confirm.
emailService mailer.Mailer
// cfg provides the configured providers, levels and languages.
cfg *config.Config
// logger records handler failures.
logger *observability.Logger
}
// NewSettingsHandler wires the given services, configuration and logger into
// a new SettingsHandler and returns it. The arguments are stored as-is.
func NewSettingsHandler(userService services.UserServiceInterface, storyService services.StoryServiceInterface, conversationService services.ConversationServiceInterface, aiService services.AIServiceInterface, learningService services.LearningServiceInterface, emailService mailer.Mailer, usageStatsSvc services.UsageStatsServiceInterface, cfg *config.Config, logger *observability.Logger) *SettingsHandler {
	handler := new(SettingsHandler)

	// Infrastructure
	handler.cfg = cfg
	handler.logger = logger

	// Domain services
	handler.userService = userService
	handler.storyService = storyService
	handler.conversationService = conversationService
	handler.aiService = aiService
	handler.learningService = learningService
	handler.usageStatsSvc = usageStatsSvc
	handler.emailService = emailService

	return handler
}
// UpdateWordOfDayEmailPreference enables or disables the word-of-day email for
// the current user.
//
// The request body is JSON of the form {"enabled": <bool>}. On success the
// handler responds with 200 and {"success": true}. A missing session user ID,
// an unparsable body or a failing service call is reported through
// HandleAppError.
//
// NOTE(review): the user ID is read from the cookie session only; confirm this
// also works for API-key-authenticated requests (GetUserIDFromSession checks
// the gin context first).
func (h *SettingsHandler) UpdateWordOfDayEmailPreference(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "update_word_of_day_email_preference")
	defer observability.FinishSpan(span, nil)

	uid, authenticated := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	payload := struct {
		Enabled bool `json:"enabled"`
	}{}
	bindErr := c.ShouldBindJSON(&payload)
	if bindErr != nil {
		invalidBody := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalidBody)
		return
	}

	updateErr := h.userService.UpdateWordOfDayEmailEnabled(ctx, uid, payload.Enabled)
	if updateErr != nil {
		HandleAppError(c, contextutils.WrapError(updateErr, "failed to update word of day email preference"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"success": true})
}
// UpdateUserSettings applies a partial settings update for the current user.
//
// The JSON body is bound to api.UserSettings. At least one of language, level,
// AI provider, AI model, API key or AI-enabled must be present; a provided
// level or language (even an empty string) must be one of the configured
// values. On success the handler responds with 200 and
// api.SuccessResponse{Success: true}; every failure is reported through
// HandleAppError.
//
// The service call runs under the context returned by TraceHandlerFunction so
// that downstream work is attached to this handler's span.
func (h *SettingsHandler) UpdateUserSettings(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "update_user_settings")
	defer observability.FinishSpan(span, nil)

	session := sessions.Default(c)
	userID, ok := session.Get(middleware.UserIDKey).(int)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var settings api.UserSettings
	if err := c.ShouldBindJSON(&settings); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			err,
		))
		return
	}

	// Validate that at least one meaningful field is provided.
	// Avoid relying on generated union/raw fields that may be non-nil for an empty JSON body.
	hasAnyField := settings.Language != nil ||
		settings.Level != nil ||
		settings.AiProvider != nil ||
		settings.AiModel != nil ||
		settings.ApiKey != nil ||
		settings.AiEnabled != nil
	if !hasAnyField {
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	// Convert api.UserSettings to models.UserSettings, recording each provided
	// field on the span. The API key value itself is never recorded, only the
	// fact that one was supplied.
	modelSettings := models.UserSettings{}
	if settings.Language != nil {
		modelSettings.Language = string(*settings.Language)
		span.SetAttributes(attribute.String("settings.language", modelSettings.Language))
	}
	if settings.Level != nil {
		modelSettings.Level = string(*settings.Level)
		span.SetAttributes(attribute.String("settings.level", modelSettings.Level))
	}
	if settings.AiProvider != nil {
		modelSettings.AIProvider = *settings.AiProvider
		span.SetAttributes(attribute.String("settings.ai_provider", modelSettings.AIProvider))
	}
	if settings.AiModel != nil {
		modelSettings.AIModel = *settings.AiModel
		span.SetAttributes(attribute.String("settings.ai_model", modelSettings.AIModel))
	}
	if settings.ApiKey != nil {
		modelSettings.AIAPIKey = *settings.ApiKey
		span.SetAttributes(attribute.Bool("settings.api_key_provided", true))
	}
	if settings.AiEnabled != nil {
		modelSettings.AIEnabled = *settings.AiEnabled
		span.SetAttributes(attribute.Bool("settings.ai_enabled", modelSettings.AIEnabled))
	}

	// Validate level if provided (including empty string)
	if settings.Level != nil {
		isValidLevel := false
		for _, level := range h.cfg.GetAllLevels() {
			if modelSettings.Level == level {
				isValidLevel = true
				break
			}
		}
		if !isValidLevel {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
	}

	// Validate language if provided (including empty string)
	if settings.Language != nil {
		isValidLanguage := false
		for _, language := range h.cfg.GetLanguages() {
			if modelSettings.Language == language {
				isValidLanguage = true
				break
			}
		}
		if !isValidLanguage {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
	}

	// Use the traced context (not the raw request context) so the update is
	// attached to this handler's span.
	if err := h.userService.UpdateUserSettings(ctx, userID, &modelSettings); err != nil {
		// Report a missing user with the bare sentinel instead of a wrapped error
		if contextutils.IsError(err, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to update settings"))
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{Success: true})
}
// TestAIConnection checks whether the AI service can be reached with the
// provider, model and API key from the request.
//
// The JSON body is bound to api.TestAIRequest. When no API key is supplied (or
// it is empty) the user's saved key for the provider is used instead. The
// outcome of the connection test is always returned with status 200 as
// {"success": <bool>, "message": <string>}; authentication, binding and
// key-lookup failures are reported through HandleAppError.
//
// Service calls run under the context returned by TraceHandlerFunction so that
// downstream work is attached to this handler's span.
func (h *SettingsHandler) TestAIConnection(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "test_ai_connection")
	defer observability.FinishSpan(span, nil)

	session := sessions.Default(c)
	userID, ok := session.Get(middleware.UserIDKey).(int)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var req api.TestAIRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request format",
			"",
			err,
		))
		return
	}

	// Extract values from API request
	provider := req.Provider
	model := req.Model
	apiKey := ""
	if req.ApiKey != nil {
		apiKey = *req.ApiKey
	}

	// If API key is empty, fall back to the key saved for this user and provider
	if apiKey == "" {
		savedKey, err := h.userService.GetUserAPIKey(ctx, userID, provider)
		if err != nil {
			HandleAppError(c, contextutils.WrapError(err, "failed to get saved API key"))
			return
		}
		apiKey = savedKey
	}

	// A failed connection test is an expected outcome, so it is reported in
	// the 200 response body rather than as an application error.
	if err := h.aiService.TestConnection(ctx, provider, model, apiKey); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": fmt.Sprintf("Model '%s': %s", model, err.Error()),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{"success": true, "message": "Connection successful"})
}
// GetProviders responds with 200 and a JSON object holding the configured AI
// providers ("providers"), all levels ("levels") and all languages
// ("languages").
func (h *SettingsHandler) GetProviders(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_providers")
	defer observability.FinishSpan(span, nil)

	c.JSON(http.StatusOK, gin.H{
		"providers": h.cfg.Providers,
		"levels":    h.cfg.GetAllLevels(),
		"languages": h.cfg.GetLanguages(),
	})
}
// GetLevels responds with 200 and the available levels plus their
// descriptions, as {"levels": ..., "level_descriptions": ...}.
//
// With a non-empty "language" query parameter the result is limited to that
// language; otherwise levels and descriptions for all languages are returned.
func (h *SettingsHandler) GetLevels(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_levels")
	defer observability.FinishSpan(span, nil)

	var payload gin.H
	if lang := c.Query("language"); lang == "" {
		payload = gin.H{
			"levels":             h.cfg.GetAllLevels(),
			"level_descriptions": h.cfg.GetAllLevelDescriptions(),
		}
	} else {
		payload = gin.H{
			"levels":             h.cfg.GetLevelsForLanguage(lang),
			"level_descriptions": h.cfg.GetLevelDescriptionsForLanguage(lang),
		}
	}
	c.JSON(http.StatusOK, payload)
}
// GetLanguages responds with 200 and the configured language info list as JSON.
func (h *SettingsHandler) GetLanguages(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_languages")
	defer observability.FinishSpan(span, nil)

	languages := h.cfg.GetLanguageInfoList()
	c.JSON(http.StatusOK, languages)
}
// CheckAPIKeyAvailability reports whether the current user has a saved API key
// for the provider named by the ":provider" path parameter.
//
// On success it responds with 200 and {"has_api_key": <bool>}. A missing
// session user ID, an empty provider or a failing lookup is reported through
// HandleAppError; lookup failures are also logged.
//
// NOTE(review): the user ID is read from the cookie session only; confirm this
// also works for API-key-authenticated requests (GetUserIDFromSession checks
// the gin context first).
func (h *SettingsHandler) CheckAPIKeyAvailability(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "check_api_key_availability")
	defer observability.FinishSpan(span, nil)

	uid, authenticated := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	providerName := c.Param("provider")
	if len(providerName) == 0 {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Ask the user service whether a key is stored for this provider
	available, lookupErr := h.userService.HasUserAPIKey(ctx, uid, providerName)
	if lookupErr != nil {
		logFields := map[string]interface{}{
			"user_id":  uid,
			"provider": providerName,
		}
		h.logger.Error(ctx, "Failed to check API key availability", lookupErr, logFields)
		HandleAppError(c, contextutils.WrapError(lookupErr, "failed to check API key availability"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"has_api_key": available})
}
// GetLearningPreferences loads the session user's learning preferences and
// returns them converted to the API schema.
//
// It responds with ErrUnauthorized when the session carries no integer user
// ID, and with a wrapped error when the learning service lookup fails.
func (h *SettingsHandler) GetLearningPreferences(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_learning_preferences")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	stored, err := h.learningService.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		logFields := map[string]interface{}{"user_id": userID}
		h.logger.Error(ctx, "Failed to get learning preferences", err, logFields)
		HandleAppError(c, contextutils.WrapError(err, "failed to get learning preferences"))
		return
	}

	// The backend model is never serialized directly; clients get the API shape.
	c.JSON(http.StatusOK, convertLearningPreferencesToAPI(stored))
}
// UpdateLearningPreferences stores the learning preferences sent in the JSON
// request body for the session user and returns the saved preferences in the
// API schema.
//
// It responds with ErrUnauthorized when the session carries no integer user
// ID, with an invalid-input application error when the body cannot be bound,
// and with a wrapped error when the learning service update fails.
func (h *SettingsHandler) UpdateLearningPreferences(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "update_learning_preferences")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var prefs models.UserLearningPreferences
	bindErr := c.ShouldBindJSON(&prefs)
	if bindErr != nil {
		invalidBody := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalidBody)
		return
	}

	// The owner always comes from the session, never from the request body.
	prefs.UserID = userID

	// Record the requested values on the span.
	span.SetAttributes(
		attribute.Bool("learning.focus_on_weak_areas", prefs.FocusOnWeakAreas),
		attribute.Bool("learning.include_review_questions", prefs.IncludeReviewQuestions),
		attribute.Float64("learning.fresh_question_ratio", prefs.FreshQuestionRatio),
		attribute.Float64("learning.known_question_penalty", prefs.KnownQuestionPenalty),
		attribute.Int("learning.review_interval_days", prefs.ReviewIntervalDays),
		attribute.Float64("learning.weak_area_boost", prefs.WeakAreaBoost),
	)

	saved, err := h.learningService.UpdateUserLearningPreferences(ctx, userID, &prefs)
	if err != nil {
		logFields := map[string]interface{}{"user_id": userID}
		h.logger.Error(ctx, "Failed to update learning preferences", err, logFields)
		HandleAppError(c, contextutils.WrapError(err, "failed to update learning preferences"))
		return
	}

	// Respond with what the service returned, converted to the API schema.
	c.JSON(http.StatusOK, convertLearningPreferencesToAPI(saved))
}
// SendTestEmail sends a test email to the address stored for the session
// user.
//
// Checks run in this order, each ending the request on failure: the session
// must carry an integer user ID, the user must load, the user must have a
// non-empty email address, and the email service must be enabled.
func (h *SettingsHandler) SendTestEmail(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "send_test_email")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for test email", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}

	// A NULL or empty address both count as "no email on file".
	hasEmail := user.Email.Valid && user.Email.String != ""
	if !hasEmail {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	if enabled := h.emailService.IsEnabled(); !enabled {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}

	recipient := user.Email.String
	templateData := map[string]interface{}{
		"Username": user.Username,
		"TestTime": "now",
		"Message":  "This is a test email to verify your email settings are working correctly.",
	}
	if err = h.emailService.SendEmail(ctx, recipient, "Test Email from Quiz App", "test_email", templateData); err != nil {
		h.logger.Error(ctx, "Failed to send test email", err, map[string]interface{}{
			"user_id": userID,
			"email":   recipient,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to send test email"))
		return
	}

	h.logger.Info(ctx, "Test email sent successfully", map[string]interface{}{
		"user_id": userID,
		"email":   recipient,
	})
	c.JSON(http.StatusOK, api.SuccessResponse{
		Success: true,
		Message: stringPtr("Test email sent successfully"),
	})
}
// ClearAllStories deletes every story that belongs to the session user.
//
// It responds with ErrUnauthorized when the session carries no integer user
// ID, and with an invalid-input application error when the handler was built
// without a story service.
func (h *SettingsHandler) ClearAllStories(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "clear_all_stories")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// The story service may be nil on this handler; refuse rather than panic.
	if h.storyService == nil {
		h.logger.Warn(ctx, "Story service not available for ClearAllStories")
		unavailable := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Clear all stories not available",
			"",
			nil,
		)
		HandleAppError(c, unavailable)
		return
	}

	err := h.storyService.DeleteAllStoriesForUser(ctx, uint(userID))
	if err != nil {
		h.logger.Error(ctx, "Failed to delete all stories for user", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to delete all stories for user"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"success": true, "message": "All stories deleted successfully"})
}
// ResetAccount clears the session user's data through the user service and
// then deletes all of the user's stories through the story service.
//
// The story service is checked for availability BEFORE anything is deleted.
// Previously the check ran after the user data had already been cleared, so a
// handler without a story service left the account half reset while still
// reporting an error to the client.
//
// It responds with ErrUnauthorized when the session carries no integer user
// ID, with an invalid-input application error when no story service is
// configured, and with a wrapped error when either deletion step fails.
func (h *SettingsHandler) ResetAccount(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "reset_account")
	defer observability.FinishSpan(span, nil)

	session := sessions.Default(c)
	userID, ok := session.Get(middleware.UserIDKey).(int)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Fail fast: both steps must be possible before the first destructive one runs.
	if h.storyService == nil {
		h.logger.Warn(ctx, "Story service not available for ResetAccount")
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Reset account not available",
			"",
			nil,
		))
		return
	}

	// Step 1: clear user data via the user service.
	if err := h.userService.ClearUserDataForUser(ctx, userID); err != nil {
		h.logger.Error(ctx, "Failed to clear user data for user during reset", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to clear user data"))
		return
	}

	// Step 2: delete all stories.
	if err := h.storyService.DeleteAllStoriesForUser(ctx, uint(userID)); err != nil {
		h.logger.Error(ctx, "Failed to delete stories during reset account", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to delete stories during reset"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"success": true, "message": "Account reset successfully"})
}
// ClearAllAIChats deletes the session user's AI conversations one by one.
//
// A single listing call with limit 1000 and offset 0 supplies the
// conversations to delete. Individual deletion failures are logged and
// skipped, and the response reports how many deletions succeeded.
//
// It responds with ErrUnauthorized when the session carries no integer user
// ID, and with an invalid-input application error when the handler was built
// without a conversation service.
func (h *SettingsHandler) ClearAllAIChats(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "clear_all_ai_chats")
	defer observability.FinishSpan(span, nil)

	userID, authenticated := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !authenticated {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	if h.conversationService == nil {
		h.logger.Warn(ctx, "Conversation service not available for ClearAllAIChats")
		unavailable := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Clear all AI chats not available",
			"",
			nil,
		)
		HandleAppError(c, unavailable)
		return
	}

	// Only one listing call is made, so at most 1000 conversations are requested.
	conversations, _, err := h.conversationService.GetUserConversations(ctx, uint(userID), 1000, 0)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user conversations for deletion", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user conversations for deletion"))
		return
	}

	var deletedCount int
	for _, conv := range conversations {
		convID := conv.Id.String()
		if delErr := h.conversationService.DeleteConversation(ctx, convID, uint(userID)); delErr != nil {
			// Best effort: log the failure and keep going with the rest.
			h.logger.Error(ctx, "Failed to delete conversation", delErr, map[string]interface{}{
				"user_id":         userID,
				"conversation_id": convID,
			})
			continue
		}
		deletedCount++
	}

	h.logger.Info(ctx, "Deleted AI conversations for user", map[string]interface{}{
		"user_id":       userID,
		"deleted_count": deletedCount,
		"total_count":   len(conversations),
	})

	summary := fmt.Sprintf("Deleted %d AI conversations successfully", deletedCount)
	c.JSON(http.StatusOK, api.SuccessResponse{
		Message: stringPtr(summary),
		Success: true,
	})
}
package handlers
import (
"net/http"
"strconv"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// SnippetsHandler handles snippets related HTTP requests.
// It holds the dependencies shared by all snippet endpoints and is built
// with NewSnippetsHandler.
type SnippetsHandler struct {
// snippetsService performs the snippet operations the handlers delegate to.
snippetsService services.SnippetsServiceInterface
// cfg is the application configuration passed to NewSnippetsHandler.
cfg *config.Config
// logger receives the handlers' warning and error logs.
logger *observability.Logger
}
// NewSnippetsHandler builds a SnippetsHandler from the given snippets
// service, configuration and logger.
func NewSnippetsHandler(snippetsService services.SnippetsServiceInterface, cfg *config.Config, logger *observability.Logger) *SnippetsHandler {
	handler := new(SnippetsHandler)
	handler.snippetsService = snippetsService
	handler.cfg = cfg
	handler.logger = logger
	return handler
}
// CreateSnippet handles POST /v1/snippets.
//
// It requires a user ID and username in the session, binds the JSON body to
// api.CreateSnippetRequest, creates the snippet through the snippets service
// and responds with 201 and the snippet in API format.
func (h *SnippetsHandler) CreateSnippet(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "create_snippet")
	defer observability.FinishSpan(span, nil)

	// Both session values are required before any work is done.
	userID, haveUser := GetUserIDFromSession(c)
	if !haveUser {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, haveName := GetUsernameFromSession(c)
	if !haveName {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(attribute.Int64("user.id", int64(userID)))
	span.SetAttributes(attribute.String("user.username", username))

	var body api.CreateSnippetRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		h.logger.Warn(ctx, "Invalid create snippet request format", map[string]interface{}{
			"error": bindErr.Error(),
		})
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	created, err := h.snippetsService.CreateSnippet(ctx, int64(userID), body)
	if err != nil {
		h.logger.Error(ctx, "Failed to create snippet", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, err)
		return
	}

	// Describe the created snippet on the span; optional fields only when set.
	span.SetAttributes(
		attribute.Int64("snippet.id", created.ID),
		attribute.Int64("user.id", int64(userID)),
		attribute.String("snippet.original_text", created.OriginalText),
		attribute.String("snippet.translated_text", created.TranslatedText),
		attribute.String("snippet.source_language", created.SourceLanguage),
		attribute.String("snippet.target_language", created.TargetLanguage),
	)
	if qid := created.QuestionID; qid != nil {
		span.SetAttributes(attribute.Int64("snippet.question_id", *qid))
	}
	if snippetCtx := created.Context; snippetCtx != nil {
		span.SetAttributes(attribute.String("snippet.context", *snippetCtx))
	}
	if level := created.DifficultyLevel; level != nil {
		span.SetAttributes(attribute.String("snippet.difficulty_level", *level))
	}

	// Map the service model onto the API response type.
	c.JSON(http.StatusCreated, api.Snippet{
		Id:              &created.ID,
		UserId:          &created.UserID,
		OriginalText:    &created.OriginalText,
		TranslatedText:  &created.TranslatedText,
		SourceLanguage:  &created.SourceLanguage,
		TargetLanguage:  &created.TargetLanguage,
		QuestionId:      created.QuestionID,
		SectionId:       created.SectionID,
		StoryId:         created.StoryID,
		Context:         created.Context,
		DifficultyLevel: created.DifficultyLevel,
		CreatedAt:       &created.CreatedAt,
		UpdatedAt:       &created.UpdatedAt,
	})
}
// GetSnippets handles GET /v1/snippets.
//
// Optional query parameters q, source_lang, target_lang, story_id, level,
// limit and offset are copied into api.GetV1SnippetsParams. Numeric
// parameters that fail to parse are ignored rather than rejected. The
// service's result is returned as-is with status 200.
func (h *SnippetsHandler) GetSnippets(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "get_snippets")
	defer observability.FinishSpan(span, nil)

	userID, haveUser := GetUserIDFromSession(c)
	if !haveUser {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, haveName := GetUsernameFromSession(c)
	if !haveName {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(attribute.Int64("user.id", int64(userID)))
	span.SetAttributes(attribute.String("user.username", username))

	// Collect the filters; an absent or empty parameter leaves its field nil.
	var params api.GetV1SnippetsParams
	if text := c.Query("q"); text != "" {
		params.Q = &text
	}
	if src := c.Query("source_lang"); src != "" {
		params.SourceLang = &src
	}
	if dst := c.Query("target_lang"); dst != "" {
		params.TargetLang = &dst
	}
	if raw := c.Query("story_id"); raw != "" {
		id, parseErr := strconv.ParseInt(raw, 10, 64)
		if parseErr == nil {
			params.StoryId = &id
		}
	}
	if lvl := c.Query("level"); lvl != "" {
		params.Level = (*api.GetV1SnippetsParamsLevel)(&lvl)
	}
	if raw := c.Query("limit"); raw != "" {
		n, parseErr := strconv.Atoi(raw)
		if parseErr == nil {
			params.Limit = &n
		}
	}
	if raw := c.Query("offset"); raw != "" {
		n, parseErr := strconv.Atoi(raw)
		if parseErr == nil {
			params.Offset = &n
		}
	}

	// Mirror every filter that was actually supplied onto the span.
	var filterAttrs []attribute.KeyValue
	if params.Limit != nil {
		filterAttrs = append(filterAttrs, attribute.Int("params.limit", *params.Limit))
	}
	if params.Offset != nil {
		filterAttrs = append(filterAttrs, attribute.Int("params.offset", *params.Offset))
	}
	if params.Q != nil {
		filterAttrs = append(filterAttrs, attribute.String("params.q", *params.Q))
	}
	if params.SourceLang != nil {
		filterAttrs = append(filterAttrs, attribute.String("params.source_lang", *params.SourceLang))
	}
	if params.TargetLang != nil {
		filterAttrs = append(filterAttrs, attribute.String("params.target_lang", *params.TargetLang))
	}
	if params.StoryId != nil {
		filterAttrs = append(filterAttrs, attribute.Int64("params.story_id", *params.StoryId))
	}
	if params.Level != nil {
		filterAttrs = append(filterAttrs, attribute.String("params.level", string(*params.Level)))
	}
	if len(filterAttrs) > 0 {
		span.SetAttributes(filterAttrs...)
	}

	result, err := h.snippetsService.GetSnippets(ctx, int64(userID), params)
	if err != nil {
		h.logger.Error(ctx, "Failed to get snippets", err, map[string]any{
			"user_id": userID,
		})
		HandleAppError(c, err)
		return
	}
	c.JSON(http.StatusOK, result)
}
// GetSnippetsByQuestion handles GET /v1/snippets/by-question/:question_id.
//
// It requires a user ID and username in the session, parses question_id as a
// base-10 int64 and responds with {"snippets": [...]} for that question.
func (h *SnippetsHandler) GetSnippetsByQuestion(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "get_snippets_by_question")
	defer observability.FinishSpan(span, nil)

	userID, haveUser := GetUserIDFromSession(c)
	if !haveUser {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, haveName := GetUsernameFromSession(c)
	if !haveName {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(attribute.Int64("user.id", int64(userID)))
	span.SetAttributes(attribute.String("user.username", username))

	// The path parameter must be a valid integer ID.
	rawID := c.Param("question_id")
	questionID, parseErr := strconv.ParseInt(rawID, 10, 64)
	if parseErr != nil {
		h.logger.Warn(ctx, "Invalid question_id parameter", map[string]any{
			"question_id": rawID,
			"error":       parseErr.Error(),
		})
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}
	span.SetAttributes(attribute.Int64("question.id", questionID))

	found, err := h.snippetsService.GetSnippetsByQuestion(ctx, int64(userID), questionID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get snippets by question", err, map[string]any{
			"user_id":     userID,
			"question_id": questionID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get snippets by question"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"snippets": found})
}
// GetSnippetsBySection handles GET /v1/snippets/by-section/:section_id.
//
// It requires a user ID and username in the session, parses section_id as a
// base-10 int64 and responds with {"snippets": [...]} for that section.
func (h *SnippetsHandler) GetSnippetsBySection(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "get_snippets_by_section")
	defer observability.FinishSpan(span, nil)

	userID, haveUser := GetUserIDFromSession(c)
	if !haveUser {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, haveName := GetUsernameFromSession(c)
	if !haveName {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(attribute.Int64("user.id", int64(userID)))
	span.SetAttributes(attribute.String("user.username", username))

	// The path parameter must be a valid integer ID.
	rawID := c.Param("section_id")
	sectionID, parseErr := strconv.ParseInt(rawID, 10, 64)
	if parseErr != nil {
		h.logger.Warn(ctx, "Invalid section_id parameter", map[string]any{
			"section_id": rawID,
			"error":      parseErr.Error(),
		})
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}
	span.SetAttributes(attribute.Int64("section.id", sectionID))

	found, err := h.snippetsService.GetSnippetsBySection(ctx, int64(userID), sectionID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get snippets by section", err, map[string]any{
			"user_id":    userID,
			"section_id": sectionID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get snippets by section"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"snippets": found})
}
// GetSnippetsByStory handles GET /v1/snippets/by-story/:story_id.
//
// It requires a user ID and username in the session, parses story_id as a
// base-10 int64 and responds with {"snippets": [...]} for that story.
func (h *SnippetsHandler) GetSnippetsByStory(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "get_snippets_by_story")
	defer observability.FinishSpan(span, nil)

	userID, haveUser := GetUserIDFromSession(c)
	if !haveUser {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, haveName := GetUsernameFromSession(c)
	if !haveName {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(attribute.Int64("user.id", int64(userID)))
	span.SetAttributes(attribute.String("user.username", username))

	// The path parameter must be a valid integer ID.
	rawID := c.Param("story_id")
	storyID, parseErr := strconv.ParseInt(rawID, 10, 64)
	if parseErr != nil {
		h.logger.Warn(ctx, "Invalid story_id parameter", map[string]any{
			"story_id": rawID,
			"error":    parseErr.Error(),
		})
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}
	span.SetAttributes(attribute.Int64("story.id", storyID))

	found, err := h.snippetsService.GetSnippetsByStory(ctx, int64(userID), storyID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get snippets by story", err, map[string]any{
			"user_id":  userID,
			"story_id": storyID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get snippets by story"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"snippets": found})
}
// SearchSnippets handles GET /v1/snippets/search.
//
// The "q" query parameter is required. "limit" defaults to 20, falls back to
// 20 when unparsable or below 1, and is capped at 100. "offset" defaults to 0
// and falls back to 0 when unparsable or negative. "source_lang" is an
// optional filter. The response echoes query, limit and offset next to the
// snippets and the total reported by the service.
func (h *SnippetsHandler) SearchSnippets(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "search_snippets")
	defer observability.FinishSpan(span, nil)

	userID, haveUser := GetUserIDFromSession(c)
	if !haveUser {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, haveName := GetUsernameFromSession(c)
	if !haveName {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(attribute.Int64("user.id", int64(userID)))
	span.SetAttributes(attribute.String("user.username", username))

	query := c.Query("q")
	if len(query) == 0 {
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	// Normalize paging: bad or out-of-range values silently fall back.
	limit, limitErr := strconv.Atoi(c.DefaultQuery("limit", "20"))
	switch {
	case limitErr != nil, limit < 1:
		limit = 20
	case limit > 100:
		limit = 100
	}
	offset, offsetErr := strconv.Atoi(c.DefaultQuery("offset", "0"))
	if offsetErr != nil || offset < 0 {
		offset = 0
	}

	span.SetAttributes(
		attribute.String("query", query),
		attribute.Int("limit", limit),
		attribute.Int("offset", offset),
	)

	// The language filter is passed as nil when the parameter is absent or empty.
	var sourceLangFilter *string
	if sourceLang := c.Query("source_lang"); sourceLang != "" {
		span.SetAttributes(attribute.String("params.source_lang", sourceLang))
		sourceLangFilter = &sourceLang
	}

	matches, total, err := h.snippetsService.SearchSnippets(ctx, int64(userID), query, limit, offset, sourceLangFilter)
	if err != nil {
		h.logger.Error(ctx, "Failed to search snippets", err, map[string]any{
			"user_id": userID,
			"query":   query,
			"limit":   limit,
			"offset":  offset,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to search snippets"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"snippets": matches,
		"query":    query,
		"total":    total,
		"limit":    limit,
		"offset":   offset,
	})
}
// GetSnippet handles GET /v1/snippets/{id}.
//
// It requires a user ID and username in the session, parses the "id" path
// parameter as a base-10 int64, loads the snippet through the snippets
// service and responds with 200 and the snippet in API format.
//
// The response carries SectionId and StoryId, the same fields CreateSnippet
// returns; they were previously dropped here, so clients reading a snippet
// lost its section/story links.
func (h *SnippetsHandler) GetSnippet(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "get_snippet")
	defer observability.FinishSpan(span, nil)

	// Get user ID and username from the session.
	userID, exists := GetUserIDFromSession(c)
	if !exists {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, exists := GetUsernameFromSession(c)
	if !exists {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	// Set once up front so the identity is on the span even on error paths.
	span.SetAttributes(
		attribute.String("user.username", username),
		attribute.Int64("user.id", int64(userID)),
	)

	// Parse snippet ID from URL parameter.
	snippetIDStr := c.Param("id")
	snippetID, err := strconv.ParseInt(snippetIDStr, 10, 64)
	if err != nil {
		h.logger.Warn(ctx, "Invalid snippet ID format", map[string]interface{}{
			"snippet_id": snippetIDStr,
			"error":      err.Error(),
		})
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	snippet, err := h.snippetsService.GetSnippet(ctx, int64(userID), snippetID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get snippet", err, map[string]interface{}{
			"user_id":    userID,
			"snippet_id": snippetID,
		})
		HandleAppError(c, err)
		return
	}

	// Convert to API response format, with the same field set as CreateSnippet.
	response := api.Snippet{
		Id:              &snippet.ID,
		UserId:          &snippet.UserID,
		OriginalText:    &snippet.OriginalText,
		TranslatedText:  &snippet.TranslatedText,
		SourceLanguage:  &snippet.SourceLanguage,
		TargetLanguage:  &snippet.TargetLanguage,
		QuestionId:      snippet.QuestionID,
		SectionId:       snippet.SectionID,
		StoryId:         snippet.StoryID,
		Context:         snippet.Context,
		DifficultyLevel: snippet.DifficultyLevel,
		CreatedAt:       &snippet.CreatedAt,
		UpdatedAt:       &snippet.UpdatedAt,
	}

	// user.id / user.username are already on the span; only snippet data is added.
	span.SetAttributes(
		attribute.Int64("snippet.id", snippet.ID),
		attribute.String("snippet.original_text", snippet.OriginalText),
		attribute.String("snippet.translated_text", snippet.TranslatedText),
		attribute.String("snippet.source_language", snippet.SourceLanguage),
		attribute.String("snippet.target_language", snippet.TargetLanguage),
	)
	if snippet.QuestionID != nil {
		span.SetAttributes(attribute.Int64("snippet.question_id", *snippet.QuestionID))
	}
	if snippet.Context != nil {
		span.SetAttributes(attribute.String("snippet.context", *snippet.Context))
	}
	if snippet.DifficultyLevel != nil {
		span.SetAttributes(attribute.String("snippet.difficulty_level", *snippet.DifficultyLevel))
	}

	c.JSON(http.StatusOK, response)
}
// UpdateSnippet handles PUT /v1/snippets/{id}.
//
// It requires a user ID and username in the session, parses the "id" path
// parameter as a base-10 int64, binds the JSON body to
// api.UpdateSnippetRequest, applies the update through the snippets service
// and responds with 200 and the updated snippet in API format.
//
// The response carries SectionId and StoryId, the same fields CreateSnippet
// returns; they were previously dropped here, so an update response made a
// snippet look detached from its section/story.
func (h *SnippetsHandler) UpdateSnippet(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "update_snippet")
	defer observability.FinishSpan(span, nil)

	// Get user ID and username from the session.
	userID, exists := GetUserIDFromSession(c)
	if !exists {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, exists := GetUsernameFromSession(c)
	if !exists {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	// Set once up front so the identity is on the span even on error paths.
	span.SetAttributes(
		attribute.String("user.username", username),
		attribute.Int64("user.id", int64(userID)),
	)

	// Parse snippet ID from URL parameter.
	snippetIDStr := c.Param("id")
	snippetID, err := strconv.ParseInt(snippetIDStr, 10, 64)
	if err != nil {
		h.logger.Warn(ctx, "Invalid snippet ID format", map[string]interface{}{
			"snippet_id": snippetIDStr,
			"error":      err.Error(),
		})
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var req api.UpdateSnippetRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		h.logger.Warn(ctx, "Invalid update snippet request format", map[string]interface{}{
			"error": err.Error(),
		})
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	snippet, err := h.snippetsService.UpdateSnippet(ctx, int64(userID), snippetID, req)
	if err != nil {
		h.logger.Error(ctx, "Failed to update snippet", err, map[string]interface{}{
			"user_id":    userID,
			"snippet_id": snippetID,
		})
		HandleAppError(c, err)
		return
	}

	// Convert to API response format, with the same field set as CreateSnippet.
	response := api.Snippet{
		Id:              &snippet.ID,
		UserId:          &snippet.UserID,
		OriginalText:    &snippet.OriginalText,
		TranslatedText:  &snippet.TranslatedText,
		SourceLanguage:  &snippet.SourceLanguage,
		TargetLanguage:  &snippet.TargetLanguage,
		QuestionId:      snippet.QuestionID,
		SectionId:       snippet.SectionID,
		StoryId:         snippet.StoryID,
		Context:         snippet.Context,
		DifficultyLevel: snippet.DifficultyLevel,
		CreatedAt:       &snippet.CreatedAt,
		UpdatedAt:       &snippet.UpdatedAt,
	}

	// user.id / user.username are already on the span; only snippet data is added.
	span.SetAttributes(
		attribute.Int64("snippet.id", snippet.ID),
		attribute.String("snippet.original_text", snippet.OriginalText),
		attribute.String("snippet.translated_text", snippet.TranslatedText),
		attribute.String("snippet.source_language", snippet.SourceLanguage),
		attribute.String("snippet.target_language", snippet.TargetLanguage),
	)
	if snippet.QuestionID != nil {
		span.SetAttributes(attribute.Int64("snippet.question_id", *snippet.QuestionID))
	}
	if snippet.Context != nil {
		span.SetAttributes(attribute.String("snippet.context", *snippet.Context))
	}
	if snippet.DifficultyLevel != nil {
		span.SetAttributes(attribute.String("snippet.difficulty_level", *snippet.DifficultyLevel))
	}

	c.JSON(http.StatusOK, response)
}
// DeleteSnippet handles DELETE /v1/snippets/{id}.
//
// It requires a user ID and username in the session, parses the "id" path
// parameter as a base-10 int64, deletes the snippet through the snippets
// service and responds with 204 and no body.
func (h *SnippetsHandler) DeleteSnippet(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "delete_snippet")
	defer observability.FinishSpan(span, nil)

	userID, haveUser := GetUserIDFromSession(c)
	if !haveUser {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, haveName := GetUsernameFromSession(c)
	if !haveName {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(attribute.String("user.username", username))
	span.SetAttributes(attribute.Int64("user.id", int64(userID)))

	// The path parameter must be a valid integer ID.
	rawID := c.Param("id")
	snippetID, parseErr := strconv.ParseInt(rawID, 10, 64)
	if parseErr != nil {
		h.logger.Warn(ctx, "Invalid snippet ID format", map[string]interface{}{
			"snippet_id": rawID,
			"error":      parseErr.Error(),
		})
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	if err := h.snippetsService.DeleteSnippet(ctx, int64(userID), snippetID); err != nil {
		h.logger.Error(ctx, "Failed to delete snippet", err, map[string]interface{}{
			"user_id":    userID,
			"snippet_id": snippetID,
		})
		HandleAppError(c, err)
		return
	}

	span.SetAttributes(
		attribute.Int64("snippet.id", snippetID),
		attribute.Int64("user.id", int64(userID)),
		attribute.String("user.username", username),
	)
	c.Status(http.StatusNoContent)
}
// DeleteAllSnippets handles DELETE /v1/snippets.
//
// It requires a user ID and username in the session and deletes all of that
// user's snippets through the snippets service. On failure the service error
// is logged and the client receives ErrInternalError; on success the
// response is 204 with no body.
func (h *SnippetsHandler) DeleteAllSnippets(c *gin.Context) {
	ctx, span := observability.TraceSnippetFunction(c.Request.Context(), "delete_all_snippets")
	defer observability.FinishSpan(span, nil)

	userID, haveUser := GetUserIDFromSession(c)
	if !haveUser {
		h.logger.Warn(ctx, "User ID not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	username, haveName := GetUsernameFromSession(c)
	if !haveName {
		h.logger.Warn(ctx, "Username not found in context")
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(attribute.String("user.username", username))
	span.SetAttributes(attribute.Int64("user.id", int64(userID)))

	if err := h.snippetsService.DeleteAllSnippets(ctx, int64(userID)); err != nil {
		h.logger.Error(ctx, "Failed to delete all snippets", err, map[string]interface{}{
			"user_id": userID,
		})
		// The underlying error is logged only; the client gets a generic one.
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}

	c.Status(http.StatusNoContent)
}
package handlers
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"github.com/jung-kurt/gofpdf"
"github.com/lib/pq"
"go.opentelemetry.io/otel/attribute"
)
// StoryHandler handles story-related HTTP requests.
// It holds the dependencies shared by the story endpoints and is built with
// NewStoryHandler.
type StoryHandler struct {
// storyService performs the story operations the handlers delegate to.
storyService services.StoryServiceInterface
// userService is used to load the requesting user (e.g. in CreateStory).
userService services.UserServiceInterface
// aiService is the AI service passed to NewStoryHandler.
aiService services.AIServiceInterface
// cfg is the application configuration passed to NewStoryHandler.
cfg *config.Config
// logger receives the handlers' error logs.
logger *observability.Logger
}
// NewStoryHandler builds a StoryHandler from the given story, user and AI
// services, configuration and logger.
func NewStoryHandler(
	storyService services.StoryServiceInterface,
	userService services.UserServiceInterface,
	aiService services.AIServiceInterface,
	cfg *config.Config,
	logger *observability.Logger,
) *StoryHandler {
	handler := new(StoryHandler)
	handler.storyService = storyService
	handler.userService = userService
	handler.aiService = aiService
	handler.cfg = cfg
	handler.logger = logger
	return handler
}
// CreateStory handles POST /v1/story.
//
// It requires a user ID in the session, binds the JSON body to
// models.CreateStoryRequest, loads the user to pick the story language and
// creates the story through the story service. On success it responds with
// 201 and the story converted to the API type.
//
// The story language is the user's preferred language, falling back to "en"
// when the preference is NULL or an empty string. Previously a valid-but-empty
// preference replaced the default with "".
//
// A creation error whose message contains "maximum archived stories limit
// reached" is reported as 403; any other creation error as 500.
func (h *StoryHandler) CreateStory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "create_story")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	var req models.CreateStoryRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		h.logger.Error(ctx, "Failed to bind story creation request", err, nil)
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid request format", err.Error())
		return
	}

	// Load the user to read the language preference.
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user", err, map[string]interface{}{
			"user_id": userID,
		})
		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to get user information", err.Error())
		return
	}

	// Default to "en" unless a non-empty preference is stored.
	language := "en"
	if pref := user.PreferredLanguage; pref.Valid && pref.String != "" {
		language = pref.String
	}

	story, err := h.storyService.CreateStory(ctx, uint(userID), language, &req)
	if err != nil {
		h.logger.Error(ctx, "Failed to create story", err, map[string]interface{}{
			"user_id": userID,
			"title":   req.Title,
		})
		// The archived-stories limit is a client-side condition, not a server failure.
		if strings.Contains(err.Error(), "maximum archived stories limit reached") {
			StandardizeHTTPError(c, http.StatusForbidden, "Maximum archived stories limit reached", err.Error())
			return
		}
		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to create story", err.Error())
		return
	}

	span.SetAttributes(
		attribute.String("story.title", story.Title),
		attribute.Int("story.id", int(story.ID)),
		attribute.String("user.language", language),
	)

	// Convert to API types to ensure proper serialization.
	apiStory := convertStoryToAPI(story)
	c.JSON(http.StatusCreated, apiStory)
}
// GetUserStories handles GET /v1/story.
//
// It lists the session user's stories. Archived stories are requested from
// the service only when the include_archived query parameter is exactly
// "true". The service result is serialized as-is with 200; a missing session
// yields 401 and a service failure yields 500.
func (h *StoryHandler) GetUserStories(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_stories")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	ownerID := uint(userID)
	withArchived := c.Query("include_archived") == "true"

	stories, err := h.storyService.GetUserStories(ctx, ownerID, withArchived)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user stories", err, map[string]interface{}{
			"user_id":          ownerID,
			"include_archived": withArchived,
		})
		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to get stories", err.Error())
		return
	}

	c.JSON(http.StatusOK, stories)
}
// GetCurrentStory handles GET /v1/story/current.
//
// Responses:
//   - 200 with the story and its sections in API form
//   - 202 with a "generating" payload when the story has no sections yet
//   - 401 when no user session is found
//   - 404 when the service returns no story
//   - 500 when the service call fails
//
// On the 200 path a view is recorded for every section; a failure to record
// a view is logged as a warning and does not fail the request.
func (h *StoryHandler) GetCurrentStory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_current_story")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	story, err := h.storyService.GetCurrentStory(ctx, uint(userID))
	switch {
	case err != nil:
		h.logger.Error(ctx, "Failed to get current story", err, map[string]interface{}{
			"user_id": uint(userID),
		})
		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to get current story", err.Error())
		return
	case story == nil:
		StandardizeHTTPError(c, http.StatusNotFound, "No current story found", "User has no active story")
		return
	case len(story.Sections) == 0:
		// A story without sections is still waiting for its first section.
		c.JSON(http.StatusAccepted, api.GeneratingResponse{
			Status:  stringPtr("generating"),
			Message: stringPtr("Story created successfully. The first section is being generated. Please check back shortly."),
		})
		return
	}

	// The story has content: count every section as viewed by this user.
	for _, sec := range story.Sections {
		viewErr := h.storyService.RecordStorySectionView(ctx, uint(userID), sec.ID)
		if viewErr == nil {
			continue
		}
		// View tracking is best-effort; log and carry on.
		h.logger.Warn(ctx, "Failed to record story section view", map[string]interface{}{
			"user_id":    userID,
			"section_id": sec.ID,
			"story_id":   story.ID,
			"error":      viewErr.Error(),
		})
	}

	// Respond with the API representation rather than the internal model.
	c.JSON(http.StatusOK, convertStoryWithSectionsToAPI(story))
}
// GetStory handles GET /v1/story/:id.
//
// The :id path parameter must parse as an unsigned 32-bit decimal number.
// Responses:
//   - 200 with the story and its sections in API form
//   - 400 when the ID is not a valid number
//   - 401 when no user session is found
//   - 404 when the service error mentions "not found" or "unauthorized"
//   - 500 for any other service failure
//
// On success a view is recorded for every section; a failure to record a
// view is logged as a warning and does not fail the request.
func (h *StoryHandler) GetStory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_story")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	storyID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid story ID", "Story ID must be a valid number")
		return
	}

	story, err := h.storyService.GetStory(ctx, uint(storyID), uint(userID))
	if err != nil {
		h.logger.Error(ctx, "Failed to get story", err, map[string]interface{}{
			"story_id": storyID,
			"user_id":  uint(userID),
		})

		reason := err.Error()
		switch {
		case strings.Contains(reason, "not found"), strings.Contains(reason, "unauthorized"):
			StandardizeHTTPError(c, http.StatusNotFound, "Story not found", "The requested story does not exist or you don't have access to it")
		default:
			StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to get story", reason)
		}
		return
	}

	// Count every section as viewed by this user.
	for _, sec := range story.Sections {
		viewErr := h.storyService.RecordStorySectionView(ctx, uint(userID), sec.ID)
		if viewErr == nil {
			continue
		}
		// View tracking is best-effort; log and carry on.
		h.logger.Warn(ctx, "Failed to record story section view", map[string]interface{}{
			"user_id":    userID,
			"section_id": sec.ID,
			"story_id":   storyID,
			"error":      viewErr.Error(),
		})
	}

	// Respond with the API representation rather than the internal model.
	c.JSON(http.StatusOK, convertStoryWithSectionsToAPI(story))
}
// GetSection handles GET /v1/story/section/:id.
//
// The :id path parameter must parse as an unsigned 32-bit decimal number.
// Responses:
//   - 200 with the section and its questions in API form
//   - 400 when the ID is not a valid number
//   - 401 when no user session is found
//   - 404 when the service error mentions "not found" or "unauthorized"
//   - 500 for any other service failure
//
// On success a view is recorded for the section; a failure to record it is
// logged as a warning and does not fail the request.
func (h *StoryHandler) GetSection(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_section")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	sectionID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid section ID", "Section ID must be a valid number")
		return
	}

	section, err := h.storyService.GetSection(ctx, uint(sectionID), uint(userID))
	if err != nil {
		h.logger.Error(ctx, "Failed to get section", err, map[string]interface{}{
			"section_id": sectionID,
			"user_id":    uint(userID),
		})

		reason := err.Error()
		switch {
		case strings.Contains(reason, "not found"), strings.Contains(reason, "unauthorized"):
			StandardizeHTTPError(c, http.StatusNotFound, "Section not found", "The requested section does not exist or you don't have access to it")
		default:
			StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to get section", reason)
		}
		return
	}

	// Count this section as viewed; tracking is best-effort.
	viewErr := h.storyService.RecordStorySectionView(ctx, uint(userID), uint(sectionID))
	if viewErr != nil {
		h.logger.Warn(ctx, "Failed to record story section view", map[string]interface{}{
			"user_id":    userID,
			"section_id": sectionID,
			"error":      viewErr.Error(),
		})
	}

	// Respond with the API representation rather than the internal model.
	c.JSON(http.StatusOK, convertStorySectionWithQuestionsToAPI(section))
}
// GenerateNextSection handles POST /v1/story/:id/generate.
//
// It generates the next section of a story on behalf of the session user,
// using the user's stored AI provider/model/API key. The user lookup and the
// generation run under a context bounded by config.AIRequestTimeout so the
// request cannot hang indefinitely.
//
// Responses:
//   - 201 with the generated section and its questions in API form
//   - 200 with an api.ErrorResponse body when the daily generation limit was
//     reached (contextutils.ErrGenerationLimitReached)
//   - 400 when the :id parameter is not a valid number
//   - 401 when no user session is found
//   - 409 when the failure is a PostgreSQL unique violation (SQLSTATE 23505)
//   - 500 for any other failure
func (h *StoryHandler) GenerateNextSection(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "generate_next_section")
	defer observability.FinishSpan(span, nil)

	// Create a timeout context for story generation to prevent hanging requests.
	// Use the configured AI request timeout for consistency with other AI operations.
	timeoutCtx, cancel := context.WithTimeout(ctx, config.AIRequestTimeout)
	defer cancel()

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	storyIDStr := c.Param("id")
	storyID, err := strconv.ParseUint(storyIDStr, 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid story ID", "Story ID must be a valid number")
		return
	}

	// Get user for AI config
	user, err := h.userService.GetUserByID(timeoutCtx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for generation", err, map[string]interface{}{
			"user_id": uint(userID),
		})
		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to get user information", err.Error())
		return
	}

	// Get user's AI configuration
	userAIConfig, apiKeyID := h.convertToServicesAIConfig(timeoutCtx, user)

	// Add user ID and API key ID to context for usage tracking
	timeoutCtx = contextutils.WithUserID(timeoutCtx, userID)
	if apiKeyID != nil {
		timeoutCtx = contextutils.WithAPIKeyID(timeoutCtx, *apiKeyID)
	}

	// Generate the story section using the shared service method (user generation)
	sectionWithQuestions, err := h.storyService.GenerateStorySection(timeoutCtx, uint(storyID), uint(userID), h.aiService, userAIConfig, models.GeneratorTypeUser)
	if err != nil {
		// Check if this is a generation limit reached error (normal business case)
		if errors.Is(err, contextutils.ErrGenerationLimitReached) {
			h.logger.Info(ctx, "User reached daily generation limit", map[string]interface{}{
				"story_id": storyID,
				"user_id":  uint(userID),
			})
			// Return 200 OK with business logic error instead of 409 Conflict
			c.JSON(http.StatusOK, api.ErrorResponse{
				Error:   stringPtr("You have already generated a section today for this story. Please try again tomorrow."),
				Details: stringPtr("daily generation limit reached"),
			})
			return
		}

		h.logger.Error(ctx, "Failed to generate story section", err, map[string]interface{}{
			"story_id": storyID,
			"user_id":  uint(userID),
		})

		// Check if this is a constraint violation (duplicate generation today).
		// errors.As walks the wrap chain, so the unique-violation is still
		// recognized when the service layer wraps the driver error; a plain
		// type assertion would only match an unwrapped *pq.Error.
		var pqErr *pq.Error
		if errors.As(err, &pqErr) && pqErr.Code == "23505" {
			StandardizeHTTPError(c, http.StatusConflict, "Cannot generate section", "You have already generated a section today for this story. Please try again tomorrow.")
			return
		}

		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to generate story section", err.Error())
		return
	}

	// Return success response with the generated section
	apiSection := convertStorySectionWithQuestionsToAPI(sectionWithQuestions)
	c.JSON(http.StatusCreated, apiSection)
}
// ArchiveStory handles POST /v1/story/:id/archive.
//
// Responses:
//   - 200 with a confirmation message
//   - 400 when the :id parameter is not a valid number
//   - 401 when no user session is found
//   - 404 when the service error mentions "not found" or "unauthorized"
//   - 500 for any other service failure
func (h *StoryHandler) ArchiveStory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "archive_story")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	storyID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid story ID", "Story ID must be a valid number")
		return
	}

	if err := h.storyService.ArchiveStory(ctx, uint(storyID), uint(userID)); err != nil {
		h.logger.Error(ctx, "Failed to archive story", err, map[string]interface{}{
			"story_id": storyID,
			"user_id":  uint(userID),
		})

		reason := err.Error()
		switch {
		case strings.Contains(reason, "not found"), strings.Contains(reason, "unauthorized"):
			StandardizeHTTPError(c, http.StatusNotFound, "Story not found", "The requested story does not exist or you don't have access to it")
		default:
			StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to archive story", reason)
		}
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "story archived successfully"})
}
// CompleteStory handles POST /v1/story/:id/complete.
//
// Responses:
//   - 200 with a confirmation message
//   - 400 when the :id parameter is not a valid number
//   - 401 when no user session is found
//   - 404 when the service error mentions "not found" or "unauthorized"
//   - 500 for any other service failure
func (h *StoryHandler) CompleteStory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "complete_story")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	storyID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid story ID", "Story ID must be a valid number")
		return
	}

	if err := h.storyService.CompleteStory(ctx, uint(storyID), uint(userID)); err != nil {
		h.logger.Error(ctx, "Failed to complete story", err, map[string]interface{}{
			"story_id": storyID,
			"user_id":  uint(userID),
		})

		reason := err.Error()
		switch {
		case strings.Contains(reason, "not found"), strings.Contains(reason, "unauthorized"):
			StandardizeHTTPError(c, http.StatusNotFound, "Story not found", "The requested story does not exist or you don't have access to it")
		default:
			StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to complete story", reason)
		}
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "story completed successfully"})
}
// SetCurrentStory handles POST /v1/story/:id/set-current.
//
// Responses:
//   - 200 with a confirmation message
//   - 400 when the :id parameter is not a valid number
//   - 401 when no user session is found
//   - 404 when the service error mentions "not found" or "unauthorized"
//   - 500 for any other service failure
func (h *StoryHandler) SetCurrentStory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "set_current_story")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	storyID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid story ID", "Story ID must be a valid number")
		return
	}

	if err := h.storyService.SetCurrentStory(ctx, uint(storyID), uint(userID)); err != nil {
		h.logger.Error(ctx, "Failed to set current story", err, map[string]interface{}{
			"story_id": storyID,
			"user_id":  uint(userID),
		})

		reason := err.Error()
		switch {
		case strings.Contains(reason, "not found"), strings.Contains(reason, "unauthorized"):
			StandardizeHTTPError(c, http.StatusNotFound, "Story not found", "The requested story does not exist or you don't have access to it")
		default:
			StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to set current story", reason)
		}
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "story set as current successfully"})
}
// DeleteStory handles DELETE /v1/story/:id.
//
// Responses:
//   - 204 with no body on success
//   - 400 when the :id parameter is not a valid number
//   - 401 when no user session is found
//   - 404 when the service error mentions "not found" or "unauthorized"
//   - 409 when the service error mentions "cannot delete active story"
//   - 500 for any other service failure
func (h *StoryHandler) DeleteStory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "delete_story")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	storyIDStr := c.Param("id")
	storyID, err := strconv.ParseUint(storyIDStr, 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid story ID", "Story ID must be a valid number")
		return
	}

	err = h.storyService.DeleteStory(ctx, uint(storyID), uint(userID))
	if err != nil {
		h.logger.Error(ctx, "Failed to delete story", err, map[string]interface{}{
			"story_id": storyID,
			"user_id":  uint(userID),
		})

		if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "unauthorized") {
			StandardizeHTTPError(c, http.StatusNotFound, "Story not found", "The requested story does not exist or you don't have access to it")
			return
		}

		if strings.Contains(err.Error(), "cannot delete active story") {
			StandardizeHTTPError(c, http.StatusConflict, "Cannot delete active story", "You cannot delete a story that is currently active")
			return
		}

		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to delete story", err.Error())
		return
	}

	// 204 carries no body, so set only the status instead of going through
	// the JSON renderer (which would still attach a JSON Content-Type).
	c.Status(http.StatusNoContent)
}
// ToggleAutoGeneration handles POST /v1/story/:id/toggle-auto-generation.
//
// The JSON body must contain a boolean "paused" field; its value is passed
// to the story service as the new pause state.
// Responses:
//   - 200 with a message and the applied "auto_generation_paused" value
//   - 400 when the :id parameter or the body is invalid
//   - 401 when no user session is found
//   - 404 when the service error mentions "not found" or "unauthorized"
//   - 500 for any other service failure
func (h *StoryHandler) ToggleAutoGeneration(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "toggle_auto_generation")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	storyID, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid story ID", "Story ID must be a valid number")
		return
	}

	// A pointer distinguishes an omitted field from an explicit false.
	var body struct {
		Paused *bool `json:"paused" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		h.logger.Error(ctx, "Failed to bind toggle auto-generation request", bindErr, nil)
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid request format", bindErr.Error())
		return
	}
	if body.Paused == nil {
		h.logger.Error(ctx, "Missing paused field in toggle auto-generation request", nil, nil)
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid request format", "paused field is required")
		return
	}
	paused := *body.Paused

	if err := h.storyService.ToggleAutoGeneration(ctx, uint(storyID), uint(userID), paused); err != nil {
		h.logger.Error(ctx, "Failed to toggle auto-generation", err, map[string]interface{}{
			"story_id": storyID,
			"user_id":  uint(userID),
			"paused":   paused,
		})

		reason := err.Error()
		switch {
		case strings.Contains(reason, "not found"), strings.Contains(reason, "unauthorized"):
			StandardizeHTTPError(c, http.StatusNotFound, "Story not found", "The requested story does not exist or you don't have access to it")
		default:
			StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to toggle auto-generation", reason)
		}
		return
	}

	var message string
	if paused {
		message = "Auto-generation paused"
	} else {
		message = "Auto-generation resumed"
	}
	c.JSON(http.StatusOK, gin.H{"message": message, "auto_generation_paused": paused})
}
// ExportStory handles GET /v1/story/:id/export.
//
// It loads the story with its sections and renders it as an A4 PDF: the
// title, optional Subject/Style/Genre lines, then each section with a
// "Section N" heading followed by its paragraphs (content is split on blank
// lines). The PDF is returned as an attachment named after the lower-cased
// title with spaces replaced by underscores.
//
// Responses:
//   - 200 with the PDF bytes (Content-Type application/pdf)
//   - 400 when the :id parameter is not a valid number
//   - 401 when no user session is found
//   - 404 when the service error mentions "not found" or "unauthorized"
//   - 500 when loading the story or rendering the PDF fails
func (h *StoryHandler) ExportStory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "export_story")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		StandardizeHTTPError(c, http.StatusUnauthorized, "Unauthorized", "User session not found or invalid")
		return
	}

	storyIDStr := c.Param("id")
	storyID, err := strconv.ParseUint(storyIDStr, 10, 32)
	if err != nil {
		StandardizeHTTPError(c, http.StatusBadRequest, "Invalid story ID", "Story ID must be a valid number")
		return
	}

	// Get the story with all sections
	story, err := h.storyService.GetStory(ctx, uint(storyID), uint(userID))
	if err != nil {
		h.logger.Error(ctx, "Failed to get story for export", err, map[string]interface{}{
			"story_id": storyID,
			"user_id":  uint(userID),
		})

		if strings.Contains(err.Error(), "not found") || strings.Contains(err.Error(), "unauthorized") {
			StandardizeHTTPError(c, http.StatusNotFound, "Story not found", "The requested story does not exist or you don't have access to it")
			return
		}

		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to get story", err.Error())
		return
	}

	// Create PDF
	pdf := gofpdf.New("P", "mm", "A4", "")

	// Use Arial (core font) for PDF generation.
	// Note: For proper Unicode support with non-Latin characters, we would need to:
	// 1. Add a TTF font file (e.g., DejaVu Sans) to frontend/public/fonts/
	// 2. Generate a .json font definition file using gofpdf's makefont utility
	// 3. Register the font using pdf.AddUTF8Font()
	// For now, Arial provides basic support.
	pdf.AddPage()

	// Use Arial consistently; size will be overridden for headings where needed
	pdf.SetFont("Arial", "B", 16)

	// Add title
	pdf.Cell(40, 10, story.Title)
	pdf.Ln(12)

	// Add story metadata if present
	pdf.SetFont("Arial", "", 10)
	if story.Subject != nil && *story.Subject != "" {
		pdf.Cell(40, 8, fmt.Sprintf("Subject: %s", *story.Subject))
		pdf.Ln(6)
	}
	if story.AuthorStyle != nil && *story.AuthorStyle != "" {
		pdf.Cell(40, 8, fmt.Sprintf("Style: %s", *story.AuthorStyle))
		pdf.Ln(6)
	}
	if story.Genre != nil && *story.Genre != "" {
		pdf.Cell(40, 8, fmt.Sprintf("Genre: %s", *story.Genre))
		pdf.Ln(6)
	}
	pdf.Ln(5)

	// Add sections
	pdf.SetFont("Arial", "", 11)
	for _, section := range story.Sections {
		// Section header
		pdf.SetFont("Arial", "B", 12)
		pdf.Cell(40, 8, fmt.Sprintf("Section %d", section.SectionNumber))
		pdf.Ln(8)

		// Section content
		pdf.SetFont("Arial", "", 11)

		// Split content into paragraphs (double line breaks)
		paragraphs := strings.Split(section.Content, "\n\n")
		for _, paragraph := range paragraphs {
			if paragraph != "" {
				// MultiCell for text wrapping
				pdf.MultiCell(0, 6, paragraph, "", "L", false)
				pdf.Ln(3)
			}
		}
		pdf.Ln(5)
	}

	// Render into memory BEFORE touching any response header. If rendering
	// fails, the JSON error response must not carry the PDF Content-Type or
	// the attachment Content-Disposition.
	var buf bytes.Buffer
	if err := pdf.Output(&buf); err != nil {
		h.logger.Error(ctx, "Failed to generate PDF", err, map[string]interface{}{
			"story_id": storyID,
		})
		StandardizeHTTPError(c, http.StatusInternalServerError, "Failed to generate PDF", err.Error())
		return
	}

	// The filename derives from the user-supplied title, so emit it as a
	// quoted-string: characters such as ';', ',' or '"' would otherwise
	// truncate or corrupt the Content-Disposition parameter.
	filename := fmt.Sprintf("story_%s.pdf", strings.ReplaceAll(strings.ToLower(story.Title), " ", "_"))
	c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))

	// c.Data sets the Content-Type to application/pdf.
	c.Data(http.StatusOK, "application/pdf", buf.Bytes())
}
// convertToServicesAIConfig builds the models.UserAIConfig for the given user.
//
// Provider and Model come from the user's nullable AIProvider/AIModel fields
// and are left empty when unset. When a provider is present, the user's
// saved API key for that provider is looked up; the key and its ID are used
// only if the lookup succeeds and returns a non-empty key. A lookup error is
// ignored and leaves the key empty and the returned ID nil.
func (h *StoryHandler) convertToServicesAIConfig(ctx context.Context, user *models.User) (*models.UserAIConfig, *int) {
	aiConfig := &models.UserAIConfig{Username: user.Username}
	if user.AIProvider.Valid {
		aiConfig.Provider = user.AIProvider.String
	}
	if user.AIModel.Valid {
		aiConfig.Model = user.AIModel.String
	}

	// Without a provider there is no key to look up.
	if aiConfig.Provider == "" {
		return aiConfig, nil
	}

	var apiKeyID *int
	key, id, err := h.userService.GetUserAPIKeyWithID(ctx, user.ID, aiConfig.Provider)
	if err == nil && key != "" {
		aiConfig.APIKey = key
		apiKeyID = id
	}
	return aiConfig, apiKeyID
}
//go:build integration
package handlers
import (
"context"
"encoding/json"
"strings"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
)
// MockAIService implements AIServiceInterface for testing.
// It wraps a real services.AIService; most methods on this type forward to
// the wrapped service when it is non-nil and return a canned value otherwise.
type MockAIService struct {
// realService is the wrapped AI service that calls are delegated to.
realService *services.AIService
}
// NewMockAIService returns a MockAIService that wraps a real AIService built
// from cfg and logger with a no-op usage stats service.
func NewMockAIService(cfg *config.Config, logger *observability.Logger) *MockAIService {
	wrapped := services.NewAIService(cfg, logger, services.NewNoopUsageStatsService())
	return &MockAIService{realService: wrapped}
}
// TestConnection simulates an AI connection test without contacting a provider.
//
// It returns an error when provider or model is empty, and an "invalid API
// key" error when apiKey is empty or contains "test". Any other input
// succeeds.
func (m *MockAIService) TestConnection(ctx context.Context, provider, model, apiKey string) error {
	if provider == "" || model == "" {
		return contextutils.ErrorWithContextf("missing provider or model")
	}
	// Empty keys and keys containing "test" simulate a rejected credential.
	if apiKey == "" || strings.Contains(apiKey, "test") {
		return contextutils.ErrorWithContextf("invalid API key")
	}
	return nil
}
// CallWithPrompt returns a canned AI-fix response when the prompt looks like
// a fix request, and otherwise delegates to the wrapped real service.
//
// A prompt counts as a fix request when it contains "fix", "Fix",
// "problematic" or "report". When no real service is wrapped, non-fix
// prompts get a fixed fallback JSON string.
func (m *MockAIService) CallWithPrompt(ctx context.Context, userConfig *models.UserAIConfig, prompt, grammar string) (string, error) {
	fixRequest := false
	for _, marker := range []string{"fix", "Fix", "problematic", "report"} {
		if strings.Contains(prompt, marker) {
			fixRequest = true
			break
		}
	}

	if !fixRequest {
		// Non-fix prompts go to the real service when one is available.
		if m.realService == nil {
			return `{"response": "Mock AI response"}`, nil
		}
		return m.realService.CallWithPrompt(ctx, userConfig, prompt, grammar)
	}

	// Canned payload shaped like an AI fix suggestion.
	const explanation = "Paris is the capital and largest city of France."
	payload, err := json.Marshal(map[string]interface{}{
		"content": map[string]interface{}{
			"question":       "What is the capital of France?",
			"options":        []string{"Paris", "London", "Berlin", "Madrid"},
			"correct_answer": 0,
			"explanation":    explanation,
		},
		"correct_answer": 0,
		"explanation":    explanation,
		"change_reason":  "Fixed grammar and improved clarity of the question.",
	})
	if err != nil {
		return "", err
	}
	return string(payload), nil
}
// Implement other required methods by delegating to real service or returning defaults

// GenerateQuestion forwards to the wrapped real service. Without one it
// returns a "not implemented in mock" error.
func (m *MockAIService) GenerateQuestion(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest) (*models.Question, error) {
	if m.realService == nil {
		return nil, contextutils.ErrorWithContextf("GenerateQuestion not implemented in mock")
	}
	return m.realService.GenerateQuestion(ctx, userConfig, req)
}
// GenerateQuestions forwards to the wrapped real service. Without one it
// returns a "not implemented in mock" error.
func (m *MockAIService) GenerateQuestions(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest) ([]*models.Question, error) {
	if m.realService == nil {
		return nil, contextutils.ErrorWithContextf("GenerateQuestions not implemented in mock")
	}
	return m.realService.GenerateQuestions(ctx, userConfig, req)
}
// GenerateQuestionsStream forwards to the wrapped real service. Without one
// it returns a "not implemented in mock" error and sends nothing on progress.
func (m *MockAIService) GenerateQuestionsStream(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest, progress chan<- *models.Question, variety *services.VarietyElements) error {
	if m.realService == nil {
		return contextutils.ErrorWithContextf("GenerateQuestionsStream not implemented in mock")
	}
	return m.realService.GenerateQuestionsStream(ctx, userConfig, req, progress, variety)
}
// GenerateChatResponse forwards to the wrapped real service. Without one it
// returns the fixed string "Mock chat response".
func (m *MockAIService) GenerateChatResponse(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIChatRequest) (string, error) {
	if m.realService == nil {
		return "Mock chat response", nil
	}
	return m.realService.GenerateChatResponse(ctx, userConfig, req)
}
// GenerateChatResponseStream forwards to the wrapped real service. Without
// one it attempts a single non-blocking send of a fixed chunk and returns nil.
func (m *MockAIService) GenerateChatResponseStream(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIChatRequest, chunks chan<- string) error {
	if m.realService == nil {
		// Non-blocking send: the chunk is dropped if the channel is not ready.
		select {
		case chunks <- "Mock streaming response":
		default:
		}
		return nil
	}
	return m.realService.GenerateChatResponseStream(ctx, userConfig, req, chunks)
}
// GetConcurrencyStats forwards to the wrapped real service. Without one it
// returns a zero-valued services.ConcurrencyStats.
func (m *MockAIService) GetConcurrencyStats() services.ConcurrencyStats {
	if m.realService == nil {
		return services.ConcurrencyStats{}
	}
	return m.realService.GetConcurrencyStats()
}
// GetQuestionBatchSize forwards to the wrapped real service. Without one it
// returns 1 for every provider.
func (m *MockAIService) GetQuestionBatchSize(provider string) int {
	if m.realService == nil {
		return 1
	}
	return m.realService.GetQuestionBatchSize(provider)
}
// VarietyService returns the wrapped real service's variety service, or nil
// when no real service is wrapped.
func (m *MockAIService) VarietyService() *services.VarietyService {
	if m.realService == nil {
		return nil
	}
	return m.realService.VarietyService()
}
// TemplateManager returns the wrapped real service's template manager, or
// nil when no real service is wrapped.
func (m *MockAIService) TemplateManager() *services.AITemplateManager {
	if m.realService == nil {
		return nil
	}
	return m.realService.TemplateManager()
}
// GenerateStoryQuestions forwards to the wrapped real service. Without one
// it returns a single canned multiple-choice question.
func (m *MockAIService) GenerateStoryQuestions(ctx context.Context, userConfig *models.UserAIConfig, req *models.StoryQuestionsRequest) ([]*models.StorySectionQuestionData, error) {
	if m.realService == nil {
		// Canned question used when there is nothing to delegate to.
		canned := &models.StorySectionQuestionData{
			QuestionText:       "What is the main character doing?",
			Options:            []string{"Reading", "Writing", "Running", "Swimming"},
			CorrectAnswerIndex: 0,
			Explanation:        stringPtr("The main character is reading a book"),
		}
		return []*models.StorySectionQuestionData{canned}, nil
	}
	return m.realService.GenerateStoryQuestions(ctx, userConfig, req)
}
// GenerateStorySection forwards to the wrapped real service. Without one it
// returns a fixed opening sentence as the section text.
func (m *MockAIService) GenerateStorySection(ctx context.Context, userConfig *models.UserAIConfig, req *models.StoryGenerationRequest) (string, error) {
	if m.realService == nil {
		// Canned text used when there is nothing to delegate to.
		return "Once upon a time, there was a brave knight who lived in a castle...", nil
	}
	return m.realService.GenerateStorySection(ctx, userConfig, req)
}
// SupportsGrammarField forwards to the wrapped real service. Without one it
// reports false for every provider.
func (m *MockAIService) SupportsGrammarField(provider string) bool {
	if m.realService == nil {
		return false
	}
	return m.realService.SupportsGrammarField(provider)
}
// Shutdown forwards to the wrapped real service's Shutdown. Without a
// wrapped service it is a no-op that returns nil.
func (m *MockAIService) Shutdown(ctx context.Context) error {
	if m.realService == nil {
		return nil
	}
	return m.realService.Shutdown(ctx)
}
package handlers
import (
"context"
"net/http"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/middleware"
"quizapp/internal/observability"
"quizapp/internal/serviceinterfaces"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// stringPtrOrEmpty dereferences s, treating a nil pointer as the empty string.
func stringPtrOrEmpty(s *string) string {
	if s != nil {
		return *s
	}
	return ""
}
// TranslationHandler handles translation related HTTP requests.
type TranslationHandler struct {
// translationService validates language codes and performs the translation.
translationService services.TranslationServiceInterface
// cfg is the configuration supplied at construction time.
cfg *config.Config
// logger is used for warning and error logging in the handler.
logger *observability.Logger
}
// NewTranslationHandler builds a TranslationHandler around the given
// translation service, configuration and logger.
func NewTranslationHandler(translationService services.TranslationServiceInterface, cfg *config.Config, logger *observability.Logger) *TranslationHandler {
	handler := new(TranslationHandler)
	handler.translationService = translationService
	handler.cfg = cfg
	handler.logger = logger
	return handler
}
// TranslateText handles text translation requests.
//
// It binds the JSON body into an api.TranslateRequest, validates it, and
// calls the translation service. Responses:
//   - 200 with the translated text, both language codes and, when the
//     service reports a confidence above zero, that confidence
//   - 400 with code INVALID_REQUEST when the body cannot be bound
//   - 400 with code VALIDATION_ERROR when validation fails
//   - 503 with code TRANSLATION_SERVICE_UNAVAILABLE when the service error
//     carries contextutils.ErrorCodeServiceUnavailable
//   - 400 with code TRANSLATION_FAILED for any other service error
func (h *TranslationHandler) TranslateText(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "translate_text")
	defer observability.FinishSpan(span, nil)

	// fail writes the error envelope shared by every failure path below.
	fail := func(status int, code, message string, cause error) {
		c.JSON(status, api.ErrorResponse{
			Code:    stringPtr(code),
			Message: stringPtr(message),
			Error:   stringPtr(cause.Error()),
		})
	}

	var req api.TranslateRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		h.logger.Warn(ctx, "Invalid translation request format", map[string]interface{}{"error": bindErr.Error()})
		fail(http.StatusBadRequest, "INVALID_REQUEST", "Invalid request format", bindErr)
		return
	}

	if validationErr := h.validateTranslationRequest(ctx, req); validationErr != nil {
		h.logger.Warn(ctx, "Translation request validation failed", map[string]interface{}{"error": validationErr.Error()})
		fail(http.StatusBadRequest, "VALIDATION_ERROR", "Request validation failed", validationErr)
		return
	}

	// The source language is optional; nil is treated as the empty string.
	sourceLang := stringPtrOrEmpty(req.SourceLanguage)

	span.SetAttributes(
		attribute.String("translation.target_language", req.TargetLanguage),
		attribute.String("translation.source_language", sourceLang),
		attribute.Int("translation.text_length", len(req.Text)),
	)

	response, err := h.translationService.Translate(ctx, serviceinterfaces.TranslateRequest{
		Text:           req.Text,
		TargetLanguage: req.TargetLanguage,
		SourceLanguage: sourceLang,
	})
	if err != nil {
		h.logger.Error(ctx, "Translation failed", err)

		// Only an explicit "service unavailable" code maps to 503; every
		// other failure is reported as a bad request.
		if contextutils.GetErrorCode(err) == contextutils.ErrorCodeServiceUnavailable {
			fail(http.StatusServiceUnavailable, "TRANSLATION_SERVICE_UNAVAILABLE", "Translation service is currently unavailable", err)
			return
		}
		fail(http.StatusBadRequest, "TRANSLATION_FAILED", "Translation failed", err)
		return
	}

	result := api.TranslateResponse{
		TranslatedText: response.TranslatedText,
		SourceLanguage: response.SourceLanguage,
		TargetLanguage: response.TargetLanguage,
	}
	// Confidence stays nil unless the service reported a positive value.
	if response.Confidence > 0 {
		confidence := float32(response.Confidence)
		result.Confidence = &confidence
	}
	c.JSON(http.StatusOK, result)
}
// validateTranslationRequest validates the translation request.
//
// It returns an error when the text is empty, when the text is longer than
// 5000 characters, when the target language code is rejected by the
// translation service, or when a non-empty source language code is rejected.
// It returns nil when every check passes. The context argument is unused.
func (h *TranslationHandler) validateTranslationRequest(_ context.Context, req api.TranslateRequest) error {
	// maxTextChars is the longest accepted text, counted in characters
	// (Unicode code points), matching the wording of the error message.
	const maxTextChars = 5000

	// Validate text length
	if len(req.Text) == 0 {
		return contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, "Text cannot be empty", "")
	}

	// len() counts bytes, which over-counts multi-byte UTF-8 text. The byte
	// length is an upper bound on the character count, so characters only
	// need counting when the byte length exceeds the limit.
	if len(req.Text) > maxTextChars {
		chars := 0
		for range req.Text {
			chars++
		}
		if chars > maxTextChars {
			return contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, "Text cannot exceed 5000 characters", "")
		}
	}

	// Validate target language
	if err := h.translationService.ValidateLanguageCode(req.TargetLanguage); err != nil {
		return contextutils.WrapError(err, "Invalid target language")
	}

	// Validate source language if provided
	if req.SourceLanguage != nil && *req.SourceLanguage != "" {
		if err := h.translationService.ValidateLanguageCode(*req.SourceLanguage); err != nil {
			return contextutils.WrapError(err, "Invalid source language")
		}
	}

	return nil
}
// RegisterRoutes mounts the translation endpoint on router.
//
// It exposes POST /v1/translate, guarded by middleware.RequireAuth and served
// by TranslateText.
func (h *TranslationHandler) RegisterRoutes(router *gin.Engine) {
	apiV1 := router.Group("/v1")
	apiV1.POST("/translate", middleware.RequireAuth(), h.TranslateText)
}
package handlers
import (
"context"
"database/sql"
"errors"
"fmt"
"html/template"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
)
// UserAdminHandler handles user management operations: listing, creating,
// updating and deleting users, resetting passwords, and updating the current
// user's own profile.
type UserAdminHandler struct {
// userService performs all user, role and settings lookups and mutations.
userService services.UserServiceInterface
// cfg is the application configuration supplied at construction time.
cfg *config.Config
// templates is left nil by NewUserAdminHandler.
templates *template.Template
// logger receives error, warning and info records from the handlers.
logger *observability.Logger
}
// NewUserAdminHandler builds a UserAdminHandler wired to the given user
// service, configuration and logger. The templates field is left nil.
func NewUserAdminHandler(userService services.UserServiceInterface, cfg *config.Config, logger *observability.Logger) *UserAdminHandler {
	handler := new(UserAdminHandler)
	handler.userService = userService
	handler.cfg = cfg
	handler.logger = logger
	// templates keeps its zero value (nil).
	return handler
}
// Request payload aliases. Each one re-exports the generated type from the api
// package so handlers in this package can refer to it by a short local name.
type (
	// UserCreateRequest is the payload for creating a new user.
	UserCreateRequest = api.UserCreateRequest

	// UserUpdateRequest is the payload for updating a user profile.
	UserUpdateRequest = api.UserUpdateRequest

	// PasswordResetRequest is the payload for resetting a user's password.
	PasswordResetRequest = api.PasswordResetRequest
)
// ProfileResponse is the JSON shape used to return a user profile from the
// user admin endpoints. It is built by convertUserToProfileResponse.
//
// Pointer fields are nil (encoded as JSON null) when the corresponding
// nullable column on the user model is not valid.
type ProfileResponse struct {
ID int `json:"id"`
Username string `json:"username"`
Email *string `json:"email"`
Timezone *string `json:"timezone"`
LastActive *time.Time `json:"last_active"`
PreferredLanguage *string `json:"preferred_language"`
CurrentLevel *string `json:"current_level"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// AIEnabled is true only when the stored flag is both valid and true.
AIEnabled bool `json:"ai_enabled"`
AIProvider *string `json:"ai_provider"`
AIModel *string `json:"ai_model"`
// Roles is omitted from the JSON output when empty.
Roles []models.Role `json:"roles,omitempty"`
// IsPaused is filled from isUserPaused.
IsPaused bool `json:"is_paused"`
}
// GetAllUsers handles GET /userz - list all users (admin only) - JSON API.
//
// It loads every user from the user service, converts each one to a
// ProfileResponse and replies 200 with {"users": [...]}. A service failure is
// logged and reported through HandleAppError.
func (h *UserAdminHandler) GetAllUsers(c *gin.Context) {
	ctx := c.Request.Context()

	users, err := h.userService.GetAllUsers(ctx)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving users", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to retrieve users"))
		return
	}

	// Convert to response format. The slice is allocated up front so that an
	// empty user list is encoded as [] rather than null, and so append never
	// has to grow it.
	userResponses := make([]ProfileResponse, 0, len(users))
	for i := range users {
		// Index instead of ranging by value to avoid copying each user.
		userResponses = append(userResponses, h.convertUserToProfileResponse(ctx, &users[i]))
	}
	c.JSON(http.StatusOK, gin.H{"users": userResponses})
}
// GetUsersPaginated handles GET /userz/paginated - list users with pagination
// (admin only).
//
// Pagination comes from parsePagination; the optional query parameters
// search, language, level, ai_provider, ai_model, ai_enabled and active are
// passed to the user service unchanged as filters. The reply is 200 with
// {"users": [{"user": ProfileResponse}, ...], "pagination": {...}}. A service
// failure is logged and reported through HandleAppError.
func (h *UserAdminHandler) GetUsersPaginated(c *gin.Context) {
	ctx := c.Request.Context()

	// Parse pagination parameters
	page, pageSize := h.parsePagination(c)

	// Parse filters
	search := c.Query("search")
	language := c.Query("language")
	level := c.Query("level")
	aiProvider := c.Query("ai_provider")
	aiModel := c.Query("ai_model")
	aiEnabled := c.Query("ai_enabled")
	active := c.Query("active")

	// Get paginated users from service
	users, total, err := h.userService.GetUsersPaginated(
		ctx,
		page,
		pageSize,
		search,
		language,
		level,
		aiProvider,
		aiModel,
		aiEnabled,
		active,
	)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving paginated users", err, map[string]interface{}{
			"page":      page,
			"page_size": pageSize,
			"search":    search,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to retrieve users"))
		return
	}

	// Convert to response format matching swagger specification. Allocating
	// the slice up front makes an empty page encode as [] rather than null.
	userItems := make([]gin.H, 0, len(users))
	for i := range users {
		// Each item nests the profile under "user" as per swagger spec.
		userItems = append(userItems, gin.H{
			"user": h.convertUserToProfileResponse(ctx, &users[i]),
		})
	}

	// Ceiling division; parsePagination never returns a pageSize below 1.
	totalPages := (total + pageSize - 1) / pageSize

	c.JSON(http.StatusOK, gin.H{
		"users": userItems,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
	})
}
// parsePagination reads the "page" and "page_size" query parameters.
//
// page defaults to 1 and is replaced only by a positive integer. pageSize
// defaults to 20 and is replaced only by an integer in the range 1..100.
// Missing, non-numeric or out-of-range values silently keep the default.
func (h *UserAdminHandler) parsePagination(c *gin.Context) (page, pageSize int) {
	// queryInt returns the named query parameter as an int; ok is false when
	// the parameter is absent, empty or not an integer.
	queryInt := func(key string) (value int, ok bool) {
		raw := c.Query(key)
		if raw == "" {
			return 0, false
		}
		parsed, convErr := strconv.Atoi(raw)
		return parsed, convErr == nil
	}

	page, pageSize = 1, 20
	if requested, ok := queryInt("page"); ok && requested > 0 {
		page = requested
	}
	if requested, ok := queryInt("page_size"); ok && requested > 0 && requested <= 100 {
		pageSize = requested
	}
	return page, pageSize
}
// CreateUser handles POST /userz - create new user (admin only).
//
// The request body is bound to UserCreateRequest. Username and password are
// required. Timezone defaults to "UTC", preferred language to "italian" and
// level to "A1" when omitted or empty; a supplied timezone must pass
// isValidTimezone. The username, and the email when one is given, must not
// already belong to another user. The user is created first and the password
// set afterwards; if setting the password fails the new user is deleted again
// on a best-effort basis. On success it replies 201 with the created profile.
func (h *UserAdminHandler) CreateUser(c *gin.Context) {
	ctx := c.Request.Context()

	var req UserCreateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request data",
			"",
			err,
		))
		return
	}

	// Validate required fields
	if req.Username == "" || req.Password == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Extract values from generated types
	timezone := "UTC"
	if req.Timezone != nil && *req.Timezone != "" {
		timezone = *req.Timezone
		// Validate timezone if provided
		if !h.isValidTimezone(timezone) {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
	}
	preferredLanguage := "italian"
	if req.PreferredLanguage != nil && *req.PreferredLanguage != "" {
		preferredLanguage = *req.PreferredLanguage
	}
	currentLevel := "A1"
	if req.CurrentLevel != nil && *req.CurrentLevel != "" {
		currentLevel = *req.CurrentLevel
	}
	email := ""
	if req.Email != nil {
		email = string(*req.Email)
	}

	// Check if username already exists
	existingUser, err := h.userService.GetUserByUsername(ctx, req.Username)
	if err != nil {
		h.logger.Error(ctx, "Error checking existing username", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to check existing username"))
		return
	}
	if existingUser != nil {
		HandleAppError(c, contextutils.ErrRecordExists)
		return
	}

	// Check if email already exists (if provided)
	if email != "" {
		existingUser, err := h.userService.GetUserByEmail(ctx, email)
		if err != nil {
			h.logger.Error(ctx, "Error checking existing email", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check email uniqueness"))
			return
		}
		if existingUser != nil {
			HandleAppError(c, contextutils.ErrRecordExists)
			return
		}
	}

	// Create user
	user, err := h.userService.CreateUserWithEmailAndTimezone(
		ctx,
		req.Username,
		email,
		timezone,
		preferredLanguage,
		currentLevel,
	)
	if err != nil {
		h.logger.Error(ctx, "Error creating user", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to create user"))
		return
	}

	// Set password
	if err := h.userService.UpdateUserPassword(ctx, user.ID, req.Password); err != nil {
		h.logger.Error(ctx, "Error setting user password", err, nil)
		// Try to clean up the created user. The cleanup stays best-effort,
		// but a failure is logged so a password-less account left behind
		// does not go unnoticed.
		if cleanupErr := h.userService.DeleteUser(ctx, user.ID); cleanupErr != nil {
			h.logger.Error(ctx, "Error cleaning up user after password failure", cleanupErr, map[string]interface{}{
				"user_id": user.ID,
			})
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to set user password"))
		return
	}

	// Return the created user profile
	c.JSON(http.StatusCreated, gin.H{
		"message": "User created successfully",
		"user":    h.convertUserToProfileResponse(ctx, user),
	})
}
// UpdateUser handles PUT /userz/:id - update user details (admin or self).
//
// Steps, in order:
//  1. Parse the :id path parameter and load the target user.
//  2. If a current user can be resolved, require that it is the target user
//     or an admin (RequireSelfOrAdmin).
//  3. Bind the body to UserUpdateRequest and validate the timezone.
//  4. Merge request values over the stored ones; nil or empty request fields
//     keep the stored value (email is replaced whenever it is non-nil).
//  5. Reject the update when a changed username or email is already taken.
//  6. Persist the profile, then the AI/learning settings if any were sent,
//     then reconcile roles when selected_roles is present.
//  7. Reply 200 with the re-read profile.
//
// The updates in step 6 are separate service calls, so a failure part-way
// through leaves the earlier ones applied.
func (h *UserAdminHandler) UpdateUser(c *gin.Context) {
	ctx := c.Request.Context()

	userID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Check if user exists
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving user", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "database error"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	// Check authorization (admin or self) - skip for direct routes (testing).
	// NOTE(review): when GetCurrentUserID fails the check is skipped
	// entirely; confirm that production routes always sit behind auth
	// middleware so this path is reachable only from tests.
	if currentUserID, err := GetCurrentUserID(c); err == nil {
		if err := RequireSelfOrAdmin(ctx, h.userService, currentUserID, userID); err != nil {
			if contextutils.IsError(err, contextutils.ErrForbidden) {
				HandleAppError(c, contextutils.ErrForbidden)
				return
			}
			h.logger.Error(ctx, "Error checking authorization", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check authorization"))
			return
		}
	}

	var req UserUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request data",
			"",
			err,
		))
		return
	}

	// Validate timezone if provided
	if req.Timezone != nil && *req.Timezone != "" && !h.isValidTimezone(*req.Timezone) {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Use existing values if not provided in request
	username := user.Username
	if req.Username != nil && *req.Username != "" {
		username = *req.Username
	}

	currentEmail := ""
	if user.Email.Valid {
		currentEmail = user.Email.String
	}
	email := currentEmail
	if req.Email != nil {
		email = string(*req.Email)
	}

	timezone := ""
	if user.Timezone.Valid {
		timezone = user.Timezone.String
	}
	if req.Timezone != nil && *req.Timezone != "" {
		timezone = *req.Timezone
	}

	languageGiven := req.PreferredLanguage != nil && *req.PreferredLanguage != ""
	preferredLanguage := ""
	if user.PreferredLanguage.Valid {
		preferredLanguage = user.PreferredLanguage.String
	}
	if languageGiven {
		preferredLanguage = *req.PreferredLanguage
	}

	levelGiven := req.CurrentLevel != nil && *req.CurrentLevel != ""
	currentLevel := ""
	if user.CurrentLevel.Valid {
		currentLevel = user.CurrentLevel.String
	}
	if levelGiven {
		currentLevel = *req.CurrentLevel
	}

	// Check if new username already exists (if changed)
	if username != user.Username {
		existingUser, err := h.userService.GetUserByUsername(ctx, username)
		if err != nil {
			h.logger.Error(ctx, "Error checking existing username", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check username uniqueness"))
			return
		}
		if existingUser != nil {
			HandleAppError(c, contextutils.ErrRecordExists)
			return
		}
	}

	// Check if new email already exists (if changed). Comparing against
	// currentEmail (empty when the user has no stored email) means a user
	// who is setting an email for the first time is checked as well.
	if email != "" && email != currentEmail {
		existingUser, err := h.userService.GetUserByEmail(ctx, email)
		if err != nil {
			h.logger.Error(ctx, "Error checking existing email", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check email uniqueness"))
			return
		}
		if existingUser != nil {
			HandleAppError(c, contextutils.ErrRecordExists)
			return
		}
	}

	// Update user profile
	if err := h.userService.UpdateUserProfile(ctx, userID, username, email, timezone); err != nil {
		h.logger.Error(ctx, "Error updating user profile", err, nil)
		// Check if the error is due to user not found
		if errors.Is(err, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to update user profile"))
		return
	}

	// Handle AI settings update if provided. Language and level are stored
	// through the same settings call, so a change to either one must trigger
	// it too (as UpdateCurrentUserProfile already does); otherwise a
	// language- or level-only update would be dropped.
	providerGiven := req.AiProvider != nil && *req.AiProvider != ""
	modelGiven := req.AiModel != nil && *req.AiModel != ""
	apiKeyGiven := req.ApiKey != nil && *req.ApiKey != ""
	needsAIUpdate := req.AiEnabled != nil || providerGiven || modelGiven || apiKeyGiven || languageGiven || levelGiven
	if needsAIUpdate {
		// Keep the stored AI flag unless the request sets it explicitly, so
		// changing e.g. only the provider does not switch AI off.
		aiEnabled := user.AIEnabled.Valid && user.AIEnabled.Bool
		if req.AiEnabled != nil {
			aiEnabled = *req.AiEnabled
		}

		// Prepare AI settings
		aiSettings := &models.UserSettings{
			Language:  preferredLanguage,
			Level:     currentLevel,
			AIEnabled: aiEnabled,
		}

		// Set AI provider and model, falling back to the stored values
		if providerGiven {
			aiSettings.AIProvider = *req.AiProvider
		} else if user.AIProvider.Valid {
			aiSettings.AIProvider = user.AIProvider.String
		}
		if modelGiven {
			aiSettings.AIModel = *req.AiModel
		} else if user.AIModel.Valid {
			aiSettings.AIModel = user.AIModel.String
		}

		// Set API key if provided
		if apiKeyGiven {
			aiSettings.AIAPIKey = *req.ApiKey
		}

		// Update AI settings
		if err := h.userService.UpdateUserSettings(ctx, userID, aiSettings); err != nil {
			h.logger.Error(ctx, "Error updating user AI settings", err, nil)
			// Check if the error is due to user not found
			if errors.Is(err, contextutils.ErrRecordNotFound) {
				HandleAppError(c, contextutils.ErrRecordNotFound)
				return
			}
			HandleAppError(c, contextutils.WrapError(err, "failed to update AI settings"))
			return
		}
	}

	// Handle role updates if provided.
	// NOTE(review): this runs for self-updates as well as admin updates;
	// confirm that routing prevents a non-admin from changing their own roles.
	if req.SelectedRoles != nil {
		// Get current user roles
		currentRoles, err := h.userService.GetUserRoles(ctx, userID)
		if err != nil {
			h.logger.Error(ctx, "Error getting current user roles", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to get current user roles"))
			return
		}

		// Get all available roles
		allRoles, err := h.userService.GetAllRoles(ctx)
		if err != nil {
			h.logger.Error(ctx, "Error getting all roles", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to get available roles"))
			return
		}

		// Create maps for efficient lookup
		currentRoleNames := make(map[string]bool, len(currentRoles))
		for _, role := range currentRoles {
			currentRoleNames[role.Name] = true
		}
		requestedRoleNames := make(map[string]bool, len(*req.SelectedRoles))
		for _, roleName := range *req.SelectedRoles {
			requestedRoleNames[roleName] = true
		}

		// Add requested roles the user does not hold yet. Names that match
		// no available role are ignored.
		for _, roleName := range *req.SelectedRoles {
			if currentRoleNames[roleName] {
				continue
			}
			// Find role by name; point at the slice element rather than at
			// the loop variable.
			var roleToAdd *models.Role
			for i := range allRoles {
				if allRoles[i].Name == roleName {
					roleToAdd = &allRoles[i]
					break
				}
			}
			if roleToAdd == nil {
				continue
			}
			if err := h.userService.AssignRole(ctx, userID, roleToAdd.ID); err != nil {
				h.logger.Error(ctx, "Error assigning role to user", err, map[string]interface{}{
					"user_id":   userID,
					"role_id":   roleToAdd.ID,
					"role_name": roleName,
				})
				HandleAppError(c, contextutils.WrapError(err, "failed to assign role"))
				return
			}
		}

		// Remove roles that are no longer selected
		for _, role := range currentRoles {
			if requestedRoleNames[role.Name] {
				continue
			}
			if err := h.userService.RemoveRole(ctx, userID, role.ID); err != nil {
				h.logger.Error(ctx, "Error removing role from user", err, map[string]interface{}{
					"user_id":   userID,
					"role_id":   role.ID,
					"role_name": role.Name,
				})
				HandleAppError(c, contextutils.WrapError(err, "failed to remove role"))
				return
			}
		}
	}

	// Get updated user
	updatedUser, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving updated user", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to retrieve updated user"))
		return
	}
	// GetUserByID reports a missing user as (nil, nil); guard against the
	// user having been removed between the update and this re-read.
	if updatedUser == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "User updated successfully",
		"user":    h.convertUserToProfileResponse(ctx, updatedUser),
	})
}
// DeleteUser handles DELETE /userz/:id - delete user (admin only).
//
// The :id path parameter must be an integer and must identify an existing
// user; otherwise an invalid-format or not-found error is returned. On
// success it replies 200 with a confirmation message.
func (h *UserAdminHandler) DeleteUser(c *gin.Context) {
	ctx := c.Request.Context()

	targetID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Make sure the user exists before attempting the delete.
	existing, lookupErr := h.userService.GetUserByID(ctx, targetID)
	switch {
	case lookupErr != nil:
		h.logger.Error(ctx, "Error retrieving user", lookupErr, nil)
		HandleAppError(c, contextutils.WrapError(lookupErr, "database error"))
		return
	case existing == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	if deleteErr := h.userService.DeleteUser(ctx, targetID); deleteErr != nil {
		h.logger.Error(ctx, "Error deleting user", deleteErr, nil)
		HandleAppError(c, contextutils.WrapError(deleteErr, "failed to delete user"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "User deleted successfully"})
}
// ResetUserPassword handles POST /userz/:id/reset-password - reset user
// password (admin only).
//
// The target user must exist and the body must carry a non-empty new
// password. Each failure is logged with the user ID; a successful reset is
// logged with the user ID and username and answered with 200.
func (h *UserAdminHandler) ResetUserPassword(c *gin.Context) {
	ctx := c.Request.Context()

	targetID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// idField builds a fresh log-field map for every call site.
	idField := func() map[string]interface{} {
		return map[string]interface{}{"user_id": targetID}
	}

	// The user must exist before its password can be reset.
	account, lookupErr := h.userService.GetUserByID(ctx, targetID)
	if lookupErr != nil {
		h.logger.Error(ctx, "Error retrieving user", lookupErr, idField())
		HandleAppError(c, contextutils.WrapError(lookupErr, "database error"))
		return
	}
	if account == nil {
		h.logger.Warn(ctx, "User not found for password reset", idField())
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	var payload PasswordResetRequest
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		h.logger.Error(ctx, "Invalid request data for password reset", bindErr, idField())
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request data",
			"",
			bindErr,
		))
		return
	}

	// An empty password is treated as a missing required field.
	if payload.NewPassword == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	if updateErr := h.userService.UpdateUserPassword(ctx, targetID, payload.NewPassword); updateErr != nil {
		h.logger.Error(ctx, "Error updating user password", updateErr, idField())
		HandleAppError(c, contextutils.WrapError(updateErr, "failed to update password"))
		return
	}

	h.logger.Info(ctx, "Password reset successful", map[string]interface{}{"user_id": targetID, "username": account.Username})
	c.JSON(http.StatusOK, gin.H{"message": "Password reset successfully"})
}
// UpdateCurrentUserProfile handles PUT /userz/profile - update current user
// profile.
//
// The target is always the authenticated user (GetCurrentUserID). The body is
// bound to UserUpdateRequest; nil or empty request fields keep the stored
// value (email is replaced whenever it is non-nil). A changed username or
// email must not already be taken. The profile is persisted first, then the
// AI/learning settings if any of them were sent. On success it replies 200
// with the re-read profile.
func (h *UserAdminHandler) UpdateCurrentUserProfile(c *gin.Context) {
	ctx := c.Request.Context()

	// Get user ID from context/session
	userID, err := GetCurrentUserID(c)
	if err != nil {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var req UserUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request data",
			"",
			err,
		))
		return
	}

	// Validate timezone if provided
	if req.Timezone != nil && *req.Timezone != "" && !h.isValidTimezone(*req.Timezone) {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	// Email validation is handled automatically by openapi_types.Email

	// Get current user
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving user", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "database error"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	// Check authorization (self-only for this endpoint): the requester and
	// the target are the same ID.
	if err := RequireSelfOrAdmin(ctx, h.userService, userID, userID); err != nil {
		if contextutils.IsError(err, contextutils.ErrForbidden) {
			HandleAppError(c, contextutils.ErrForbidden)
			return
		}
		h.logger.Error(ctx, "Error checking authorization", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to check authorization"))
		return
	}

	// Use existing values if not provided in request
	username := user.Username
	if req.Username != nil && *req.Username != "" {
		username = *req.Username
	}

	currentEmail := ""
	if user.Email.Valid {
		currentEmail = user.Email.String
	}
	email := currentEmail
	if req.Email != nil {
		email = string(*req.Email)
	}

	timezone := ""
	if user.Timezone.Valid {
		timezone = user.Timezone.String
	}
	if req.Timezone != nil && *req.Timezone != "" {
		timezone = *req.Timezone
	}

	// Check if new username already exists (if changed)
	if username != user.Username {
		existingUser, err := h.userService.GetUserByUsername(ctx, username)
		if err != nil {
			h.logger.Error(ctx, "Error checking existing username", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check username uniqueness"))
			return
		}
		if existingUser != nil {
			HandleAppError(c, contextutils.ErrRecordExists)
			return
		}
	}

	// Check if new email already exists (if changed). Comparing against
	// currentEmail (empty when the user has no stored email) means a user
	// who is setting an email for the first time is checked as well.
	if email != "" && email != currentEmail {
		existingUser, err := h.userService.GetUserByEmail(ctx, email)
		if err != nil {
			h.logger.Error(ctx, "Error checking existing email", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check email uniqueness"))
			return
		}
		if existingUser != nil {
			HandleAppError(c, contextutils.ErrRecordExists)
			return
		}
	}

	// Use existing language/level values if not provided in request
	languageGiven := req.PreferredLanguage != nil && *req.PreferredLanguage != ""
	preferredLanguage := ""
	if user.PreferredLanguage.Valid {
		preferredLanguage = user.PreferredLanguage.String
	}
	if languageGiven {
		preferredLanguage = *req.PreferredLanguage
	}

	levelGiven := req.CurrentLevel != nil && *req.CurrentLevel != ""
	currentLevel := ""
	if user.CurrentLevel.Valid {
		currentLevel = user.CurrentLevel.String
	}
	if levelGiven {
		currentLevel = *req.CurrentLevel
	}

	// Update user profile
	if err := h.userService.UpdateUserProfile(ctx, userID, username, email, timezone); err != nil {
		h.logger.Error(ctx, "Error updating user profile", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to update user profile"))
		return
	}

	// Handle AI settings update if provided
	providerGiven := req.AiProvider != nil && *req.AiProvider != ""
	modelGiven := req.AiModel != nil && *req.AiModel != ""
	apiKeyGiven := req.ApiKey != nil && *req.ApiKey != ""
	needsAIUpdate := req.AiEnabled != nil || providerGiven || modelGiven || languageGiven || levelGiven || apiKeyGiven
	if needsAIUpdate {
		// Keep the stored AI flag unless the request sets it explicitly.
		// Without this, changing only the language or level would silently
		// switch AI off.
		aiEnabled := user.AIEnabled.Valid && user.AIEnabled.Bool
		if req.AiEnabled != nil {
			aiEnabled = *req.AiEnabled
		}

		aiSettings := &models.UserSettings{
			Language:  preferredLanguage,
			Level:     currentLevel,
			AIEnabled: aiEnabled,
		}

		// Provider and model fall back to the stored values when omitted.
		if providerGiven {
			aiSettings.AIProvider = *req.AiProvider
		} else if user.AIProvider.Valid {
			aiSettings.AIProvider = user.AIProvider.String
		}
		if modelGiven {
			aiSettings.AIModel = *req.AiModel
		} else if user.AIModel.Valid {
			aiSettings.AIModel = user.AIModel.String
		}
		if apiKeyGiven {
			aiSettings.AIAPIKey = *req.ApiKey
		}

		if err := h.userService.UpdateUserSettings(ctx, userID, aiSettings); err != nil {
			h.logger.Error(ctx, "Error updating user AI settings", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to update AI settings"))
			return
		}
	}

	// Get updated user
	updatedUser, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving updated user", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to retrieve updated profile"))
		return
	}
	// GetUserByID reports a missing user as (nil, nil); guard against the
	// user having been removed between the update and this re-read.
	if updatedUser == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "Profile updated successfully",
		"user":    h.convertUserToProfileResponse(ctx, updatedUser),
	})
}
// isUserPaused reports whether the worker_settings table marks the user as
// paused.
//
// It looks up the row whose setting_key is "user_pause_<userID>" and returns
// true only when its setting_value is exactly "true". A missing row means not
// paused. Any other query error is logged as a warning and also treated as
// not paused, so a lookup failure never fails the caller.
func (h *UserAdminHandler) isUserPaused(ctx context.Context, userID int) bool {
	const pauseQuery = `SELECT setting_value FROM worker_settings WHERE setting_key = $1`

	pauseKey := fmt.Sprintf("user_pause_%d", userID)

	var stored string
	scanErr := h.userService.GetDB().QueryRowContext(ctx, pauseQuery, pauseKey).Scan(&stored)
	switch {
	case scanErr == nil:
		return stored == "true"
	case errors.Is(scanErr, sql.ErrNoRows):
		// No setting recorded for this user.
		return false
	default:
		h.logger.Warn(ctx, "Failed to check user pause status", map[string]interface{}{
			"user_id": userID,
			"error":   scanErr.Error(),
		})
		return false
	}
}
// Helper functions
// convertUserToProfileResponse maps a User model onto the ProfileResponse
// returned by the API.
//
// Nullable columns become nil pointers when not valid. Roles are fetched from
// the user service; if that lookup fails a warning is logged and an empty
// role list is used instead of failing the response. The pause flag comes
// from isUserPaused.
func (h *UserAdminHandler) convertUserToProfileResponse(ctx context.Context, user *models.User) ProfileResponse {
	assignedRoles, rolesErr := h.userService.GetUserRoles(ctx, user.ID)
	if rolesErr != nil {
		// Best effort: a role lookup failure must not break the response.
		h.logger.Warn(ctx, "Failed to get user roles", map[string]interface{}{
			"user_id": user.ID,
			"error":   rolesErr.Error(),
		})
		assignedRoles = []models.Role{}
	}

	profile := ProfileResponse{
		ID:        user.ID,
		Username:  user.Username,
		CreatedAt: user.CreatedAt,
		UpdatedAt: user.UpdatedAt,
		Roles:     assignedRoles,
	}

	// Nullable columns.
	profile.Email = nullStringToPointer(user.Email)
	profile.Timezone = nullStringToPointer(user.Timezone)
	profile.LastActive = nullTimeToPointer(user.LastActive)
	profile.PreferredLanguage = nullStringToPointer(user.PreferredLanguage)
	profile.CurrentLevel = nullStringToPointer(user.CurrentLevel)

	// AI settings; the flag counts as enabled only when valid and true.
	profile.AIEnabled = user.AIEnabled.Valid && user.AIEnabled.Bool
	profile.AIProvider = nullStringToPointer(user.AIProvider)
	profile.AIModel = nullStringToPointer(user.AIModel)

	profile.IsPaused = h.isUserPaused(ctx, user.ID)
	return profile
}
// isValidTimezone reports whether tz names a usable timezone.
//
// A name is accepted when time.LoadLocation can load it. If loading fails,
// any casing of "UTC" is still accepted as a fallback.
func (h *UserAdminHandler) isValidTimezone(tz string) bool {
	if _, loadErr := time.LoadLocation(tz); loadErr == nil {
		return true
	}
	// Loading failed: only a case-insensitive "UTC" is still allowed.
	return strings.ToUpper(tz) == "UTC"
}
// nullStringToPointer converts a sql.NullString to a *string: nil when the
// value is not valid, otherwise a pointer to a copy of the string.
func nullStringToPointer(ns sql.NullString) *string {
	if !ns.Valid {
		return nil
	}
	value := ns.String
	return &value
}
// nullTimeToPointer converts a sql.NullTime to a *time.Time: nil when the
// value is not valid, otherwise a pointer to a copy of the time.
func nullTimeToPointer(nt sql.NullTime) *time.Time {
	if !nt.Valid {
		return nil
	}
	value := nt.Time
	return &value
}
package handlers
import (
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"net/http"
	"strings"

	"quizapp/internal/observability"
	contextutils "quizapp/internal/utils"

	"github.com/gin-gonic/gin"
	"go.opentelemetry.io/otel/attribute"
)
// verbConjugationFS is the embedded file system rooted at the
// data/verb-conjugations directory. The handlers in this file read info.json
// and one sub-directory per language from it.
//
//go:embed data/verb-conjugations
var verbConjugationFS embed.FS
// VerbConjugationHandler handles verb conjugation related HTTP requests. All
// data is served from the embedded verbConjugationFS.
type VerbConjugationHandler struct {
// logger records read and parse failures.
logger *observability.Logger
}
// NewVerbConjugationHandler builds a VerbConjugationHandler that reports
// failures through the given logger.
func NewVerbConjugationHandler(logger *observability.Logger) *VerbConjugationHandler {
	handler := new(VerbConjugationHandler)
	handler.logger = logger
	return handler
}
// VerbConjugationData represents the complete verb conjugation data for a
// language. It is the response body of GetVerbConjugations, which copies
// Language and LanguageName from the first verb file it reads.
type VerbConjugationData struct {
Language string `json:"language"`
LanguageName string `json:"languageName"`
// Verbs holds one entry per .json file found in the language directory.
Verbs []VerbConjugation `json:"verbs"`
}
// VerbConjugation represents a single verb with its conjugations across all
// tenses. It is decoded from one embedded verb JSON file and is the response
// body of GetVerbConjugation.
type VerbConjugation struct {
Language string `json:"language"`
LanguageName string `json:"languageName"`
Infinitive string `json:"infinitive"`
// InfinitiveEn is presumably the English translation of the infinitive.
InfinitiveEn string `json:"infinitiveEn"`
Slug string `json:"slug,omitempty"` // Optional ASCII slug for filename when infinitive has Unicode combining characters
Category string `json:"category"`
Tenses []Tense `json:"tenses"`
}
// Tense represents a grammatical tense with its conjugations and description.
type Tense struct {
TenseID string `json:"tenseId"`
TenseName string `json:"tenseName"`
// TenseNameEn is presumably the English name of the tense.
TenseNameEn string `json:"tenseNameEn"`
Description string `json:"description"`
// Conjugations lists the conjugated forms of the verb in this tense.
Conjugations []Conjugation `json:"conjugations"`
}
// Conjugation represents a single conjugated form with example sentence.
type Conjugation struct {
Pronoun string `json:"pronoun"`
Form string `json:"form"`
ExampleSentence string `json:"exampleSentence"`
// ExampleSentenceEn is presumably the English translation of ExampleSentence.
ExampleSentenceEn string `json:"exampleSentenceEn"`
}
// VerbConjugationInfo represents metadata about the verb conjugation section.
// It is decoded from data/verb-conjugations/info.json and returned as-is by
// GetVerbConjugationInfo.
type VerbConjugationInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Emoji string `json:"emoji"`
Description string `json:"description"`
}
// GetVerbConjugationInfo returns metadata about verb conjugations.
//
// It reads data/verb-conjugations/info.json from the embedded file system,
// decodes it into VerbConjugationInfo and replies 200 with the result. Read
// and parse failures are logged and reported through HandleAppError.
func (h *VerbConjugationHandler) GetVerbConjugationInfo(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_verb_conjugation_info")
	defer observability.FinishSpan(span, nil)

	raw, readErr := verbConjugationFS.ReadFile("data/verb-conjugations/info.json")
	if readErr != nil {
		h.logger.Error(c.Request.Context(), "Failed to read verb conjugation info", readErr)
		HandleAppError(c, contextutils.WrapError(readErr, "failed to read verb conjugation info"))
		return
	}

	metadata := VerbConjugationInfo{}
	if parseErr := json.Unmarshal(raw, &metadata); parseErr != nil {
		h.logger.Error(c.Request.Context(), "Failed to parse verb conjugation info", parseErr)
		HandleAppError(c, contextutils.WrapError(parseErr, "failed to parse verb conjugation info"))
		return
	}

	c.JSON(http.StatusOK, metadata)
}
// GetVerbConjugations returns all verbs for a specific language.
//
// The :language path parameter selects a directory under
// data/verb-conjugations in the embedded file system. Every .json file
// directly inside it is decoded as a VerbConjugation. The response's Language
// and LanguageName come from the first verb read. It replies 404 when the
// directory does not exist or contains no verb files, and reports any other
// read or parse failure through HandleAppError.
func (h *VerbConjugationHandler) GetVerbConjugations(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_verb_conjugations")
	defer observability.FinishSpan(span, nil)

	languageCode := c.Param("language")
	if languageCode == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	span.SetAttributes(attribute.String("language", languageCode))

	// Read all verb files in the language directory
	languageDir := fmt.Sprintf("data/verb-conjugations/%s", languageCode)
	entries, err := verbConjugationFS.ReadDir(languageDir)
	if err != nil {
		// A missing directory is reported as an error wrapping
		// fs.ErrNotExist; test for that sentinel instead of matching on the
		// error text, which also embeds the caller-supplied path.
		if errors.Is(err, fs.ErrNotExist) {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		h.logger.Error(c.Request.Context(), "Failed to read verb conjugation directory", err, map[string]interface{}{
			"language":  languageCode,
			"directory": languageDir,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to read verb conjugation directory"))
		return
	}

	verbs := make([]VerbConjugation, 0, len(entries))
	var languageName string
	var language string

	// Read each verb file
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".json") {
			continue
		}
		filename := fmt.Sprintf("%s/%s", languageDir, entry.Name())
		data, err := verbConjugationFS.ReadFile(filename)
		if err != nil {
			h.logger.Error(c.Request.Context(), "Failed to read verb file", err, map[string]interface{}{
				"language": languageCode,
				"filename": filename,
			})
			HandleAppError(c, contextutils.WrapError(err, "failed to read verb file"))
			return
		}

		var verb VerbConjugation
		if err := json.Unmarshal(data, &verb); err != nil {
			h.logger.Error(c.Request.Context(), "Failed to parse verb file", err, map[string]interface{}{
				"language": languageCode,
				"filename": filename,
			})
			HandleAppError(c, contextutils.WrapError(err, "failed to parse verb file"))
			return
		}

		// Set language metadata from first verb (all verbs in a directory
		// should have the same language)
		if languageName == "" {
			languageName = verb.LanguageName
			language = verb.Language
		}
		verbs = append(verbs, verb)
	}

	if len(verbs) == 0 {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	c.JSON(http.StatusOK, VerbConjugationData{
		Language:     language,
		LanguageName: languageName,
		Verbs:        verbs,
	})
}
// GetVerbConjugation returns a specific verb's conjugations for a language.
//
// The :language and :verb path parameters select the embedded file
// data/verb-conjugations/<language>/<verb>.json, which is decoded as a
// VerbConjugation and returned with 200. It replies 404 when the file does
// not exist and reports any other read or parse failure through
// HandleAppError.
func (h *VerbConjugationHandler) GetVerbConjugation(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_verb_conjugation")
	defer observability.FinishSpan(span, nil)

	languageCode := c.Param("language")
	verbInfinitive := c.Param("verb")
	if languageCode == "" || verbInfinitive == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	span.SetAttributes(attribute.String("language", languageCode))
	span.SetAttributes(attribute.String("verb", verbInfinitive))

	// Read the specific verb file
	filename := fmt.Sprintf("data/verb-conjugations/%s/%s.json", languageCode, verbInfinitive)
	data, err := verbConjugationFS.ReadFile(filename)
	if err != nil {
		// A missing file is reported as an error wrapping fs.ErrNotExist;
		// test for that sentinel instead of matching on the error text,
		// which also embeds the caller-supplied path.
		if errors.Is(err, fs.ErrNotExist) {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		h.logger.Error(c.Request.Context(), "Failed to read verb file", err, map[string]interface{}{
			"language": languageCode,
			"verb":     verbInfinitive,
			"filename": filename,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to read verb file"))
		return
	}

	var verb VerbConjugation
	if err := json.Unmarshal(data, &verb); err != nil {
		h.logger.Error(c.Request.Context(), "Failed to parse verb file", err, map[string]interface{}{
			"language": languageCode,
			"verb":     verbInfinitive,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to parse verb file"))
		return
	}

	c.JSON(http.StatusOK, verb)
}
// GetAvailableLanguages returns the list of available languages for verb conjugations.
//
// It lists the entries of data/verb-conjugations in verbConjugationFS and
// responds with the names of the directory entries as a JSON array of strings
// (status 200). Plain files in that directory are ignored. If the directory
// cannot be read, the error is logged and reported as a wrapped error.
func (h *VerbConjugationHandler) GetAvailableLanguages(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_available_languages")
	defer observability.FinishSpan(span, nil)

	// Read all entries in the verb-conjugations directory
	entries, err := verbConjugationFS.ReadDir("data/verb-conjugations")
	if err != nil {
		h.logger.Error(c.Request.Context(), "Failed to read verb conjugation directory", err)
		HandleAppError(c, contextutils.WrapError(err, "failed to read verb conjugation directory"))
		return
	}

	// Start from a non-nil slice: encoding/json renders a nil slice as null,
	// which would make the response stop being an array when no language
	// directory exists.
	languages := make([]string, 0, len(entries))
	for _, entry := range entries {
		// Only include directories (language folders), skip files like info.json
		if entry.IsDir() {
			languages = append(languages, entry.Name())
		}
	}

	c.JSON(http.StatusOK, languages)
}
package handlers
import (
"context"
"fmt"
"html/template"
"net/http"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// WordOfTheDayHandler handles word of the day HTTP requests.
// It holds the dependencies shared by the word-of-the-day endpoints.
type WordOfTheDayHandler struct {
// userService supplies the GetUserByID lookup that ParseDateInUserTimezone
// passes to the shared date-parsing utility.
userService services.UserServiceInterface
// wordOfTheDayService provides the word for a given day and the word history.
wordOfTheDayService services.WordOfTheDayServiceInterface
// cfg is the application configuration stored by the constructor.
cfg *config.Config
// logger records failures of the service calls made by the handlers.
logger *observability.Logger
}
// NewWordOfTheDayHandler builds a WordOfTheDayHandler from its dependencies.
// The arguments are stored as given; no validation is performed.
func NewWordOfTheDayHandler(
	userService services.UserServiceInterface,
	wordOfTheDayService services.WordOfTheDayServiceInterface,
	cfg *config.Config,
	logger *observability.Logger,
) *WordOfTheDayHandler {
	handler := new(WordOfTheDayHandler)
	handler.userService = userService
	handler.wordOfTheDayService = wordOfTheDayService
	handler.cfg = cfg
	handler.logger = logger
	return handler
}
// ParseDateInUserTimezone parses dateStr for the user identified by userID.
// The work is done by contextutils.ParseDateInUserTimezone; this method only
// supplies the handler's user service as the user lookup function and returns
// that helper's results unchanged.
func (h *WordOfTheDayHandler) ParseDateInUserTimezone(ctx context.Context, userID int, dateStr string) (time.Time, string, error) {
	lookupUser := h.userService.GetUserByID
	return contextutils.ParseDateInUserTimezone(ctx, userID, dateStr, lookupUser)
}
// GetWordOfTheDay handles GET /v1/word-of-day/:date
//
// It requires a user in the session (ErrUnauthorized otherwise) and a non-empty
// "date" route parameter (ErrMissingRequired otherwise). The date is parsed via
// ParseDateInUserTimezone; a parse error whose message contains
// "invalid date format" becomes ErrInvalidFormat, any other error is wrapped.
// The word returned by the word-of-the-day service is written as JSON with
// status 200.
func (h *WordOfTheDayHandler) GetWordOfTheDay(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_word_of_the_day")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	rawDate := c.Param("date")
	if rawDate == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	day, tz, err := h.ParseDateInUserTimezone(ctx, userID, rawDate)
	switch {
	case err == nil:
		// Parsed successfully; continue below.
	case strings.Contains(err.Error(), "invalid date format"):
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	default:
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", rawDate),
		attribute.String("timezone", tz),
	)

	word, err := h.wordOfTheDayService.GetWordOfTheDay(ctx, userID, day)
	if err != nil {
		logFields := map[string]interface{}{
			"user_id": userID,
			"date":    rawDate,
		}
		h.logger.Error(ctx, "Failed to get word of the day", err, logFields)
		HandleAppError(c, contextutils.WrapError(err, "failed to get word of the day"))
		return
	}

	c.JSON(http.StatusOK, word)
}
// GetWordOfTheDayToday handles GET /v1/word-of-day
// It resolves "today" in the user's timezone and returns that day's word.
//
// The handler requires a user in the session (ErrUnauthorized otherwise).
// Today's date is first taken from the server clock and parsed through
// ParseDateInUserTimezone, which also reports the timezone it used. If that
// timezone can be loaded and the calendar date there differs from the server's
// local date, the date is resolved a second time with the user's own date, so
// a user on the other side of a date boundary gets the word for their day.
// The word is written as JSON with status 200; failures are wrapped and passed
// to HandleAppError.
func (h *WordOfTheDayHandler) GetWordOfTheDayToday(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_word_of_the_day_today")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	const dateLayout = "2006-01-02"

	// First pass: use the server clock's date. This pass is also what tells
	// us which timezone applies to the user.
	now := time.Now()
	todayStr := now.Format(dateLayout)
	date, timezone, err := h.ParseDateInUserTimezone(ctx, userID, todayStr)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to resolve today's date"))
		return
	}

	// Second pass, only when needed: the server's local date may not be the
	// user's date. NOTE(review): assumes the returned timezone string is an
	// IANA location name; if it cannot be loaded the server-clock date is kept.
	if timezone != "" {
		if loc, locErr := time.LoadLocation(timezone); locErr == nil {
			if userToday := now.In(loc).Format(dateLayout); userToday != todayStr {
				todayStr = userToday
				date, timezone, err = h.ParseDateInUserTimezone(ctx, userID, todayStr)
				if err != nil {
					HandleAppError(c, contextutils.WrapError(err, "failed to resolve today's date"))
					return
				}
			}
		}
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", todayStr),
		attribute.String("timezone", timezone),
	)

	// Get word of the day
	word, err := h.wordOfTheDayService.GetWordOfTheDay(ctx, userID, date)
	if err != nil {
		h.logger.Error(ctx, "Failed to get today's word of the day", err, map[string]interface{}{
			"user_id": userID,
			"date":    todayStr,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get word of the day"))
		return
	}

	c.JSON(http.StatusOK, word)
}
// GetWordOfTheDayEmbed handles GET /v1/word-of-day/:date/embed
// This endpoint returns HTML for embedding in an iframe. Requires an authenticated session.
//
// The date comes from the "date" route parameter, then from the "date" query
// parameter, and finally defaults to today. When the default is used, "today"
// is re-resolved in the timezone reported by ParseDateInUserTimezone if the
// calendar date there differs from the server's local date. All responses,
// including errors, are sent as text/html: 401 without a session user, 400
// when the date cannot be parsed, 500 when the word cannot be loaded, and 200
// with the rendered card otherwise.
func (h *WordOfTheDayHandler) GetWordOfTheDayEmbed(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_word_of_the_day_embed")
	defer observability.FinishSpan(span, nil)

	const (
		htmlContentType = "text/html; charset=utf-8"
		dateLayout      = "2006-01-02"
	)

	// The user is always taken from the session, never from request parameters.
	userID, exists := GetUserIDFromSession(c)
	if !exists {
		c.Data(http.StatusUnauthorized, htmlContentType, []byte("Unauthorized"))
		return
	}

	// Resolve date parameter from path, query, or default to today's date
	dateStr := c.Param("date")
	if dateStr == "" {
		dateStr = c.Query("date")
	}
	now := time.Now()
	defaultedToToday := dateStr == ""
	if defaultedToToday {
		dateStr = now.Format(dateLayout)
	}

	// Parse date in user's timezone
	date, timezone, err := h.ParseDateInUserTimezone(ctx, userID, dateStr)
	if err != nil {
		c.Data(http.StatusBadRequest, htmlContentType, []byte("Invalid date format"))
		return
	}

	// The default above came from the server clock, whose local date may not
	// be the user's date. NOTE(review): assumes the returned timezone string
	// is an IANA location name; if it cannot be loaded the server date is kept.
	if defaultedToToday && timezone != "" {
		if loc, locErr := time.LoadLocation(timezone); locErr == nil {
			if userToday := now.In(loc).Format(dateLayout); userToday != dateStr {
				dateStr = userToday
				date, timezone, err = h.ParseDateInUserTimezone(ctx, userID, dateStr)
				if err != nil {
					c.Data(http.StatusBadRequest, htmlContentType, []byte("Invalid date format"))
					return
				}
			}
		}
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", dateStr),
		attribute.String("timezone", timezone),
	)

	// Get word of the day
	word, err := h.wordOfTheDayService.GetWordOfTheDay(ctx, userID, date)
	if err != nil {
		h.logger.Error(ctx, "Failed to get word of the day for embed", err, map[string]interface{}{
			"user_id": userID,
			"date":    dateStr,
		})
		c.Data(http.StatusInternalServerError, htmlContentType, []byte("Failed to load word of the day"))
		return
	}

	// Render HTML template
	html := h.renderEmbedHTML(word)
	c.Data(http.StatusOK, htmlContentType, []byte(html))
}
// GetWordOfTheDayHistory handles GET /v1/word-of-day/history
//
// Both the "start_date" and "end_date" query parameters are required
// (ErrMissingRequired otherwise) and are parsed via ParseDateInUserTimezone.
// The response is a JSON object with the words returned by the service under
// "words" and their number under "count".
func (h *WordOfTheDayHandler) GetWordOfTheDayHistory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_word_of_the_day_history")
	defer observability.FinishSpan(span, nil)

	userID, loggedIn := GetUserIDFromSession(c)
	if !loggedIn {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Both ends of the range must be supplied by the caller.
	fromStr, toStr := c.Query("start_date"), c.Query("end_date")
	if fromStr == "" || toStr == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	from, _, err := h.ParseDateInUserTimezone(ctx, userID, fromStr)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "invalid start_date"))
		return
	}
	to, _, err := h.ParseDateInUserTimezone(ctx, userID, toStr)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "invalid end_date"))
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("start_date", fromStr),
		attribute.String("end_date", toStr),
	)

	history, err := h.wordOfTheDayService.GetWordHistory(ctx, userID, from, to)
	if err != nil {
		logFields := map[string]interface{}{
			"user_id":    userID,
			"start_date": fromStr,
			"end_date":   toStr,
		}
		h.logger.Error(ctx, "Failed to get word of the day history", err, logFields)
		HandleAppError(c, contextutils.WrapError(err, "failed to get word history"))
		return
	}

	payload := gin.H{
		"words": history,
		"count": len(history),
	}
	c.JSON(http.StatusOK, payload)
}
// renderEmbedHTML produces the complete HTML page shown by the embed endpoint.
//
// A nil word yields a short "unavailable" page. Otherwise the word's fields
// are rendered through an html/template card layout; the date is formatted as
// "January 2, 2006". If the template cannot be parsed or executed, a minimal
// HTML page containing the error text is returned instead.
func (h *WordOfTheDayHandler) renderEmbedHTML(word *models.WordOfTheDayDisplay) string {
	// Without a word there is nothing to render; answer with a placeholder
	// page rather than dereferencing nil.
	if word == nil {
		return `<html><head><meta charset="UTF-8"></head><body>Word of the Day is unavailable.</body></html>`
	}

	const embedPage = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Word of the Day</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: #333;
padding: 20px;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}
.card {
background: white;
border-radius: 16px;
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
padding: 30px;
max-width: 500px;
width: 100%;
}
.date {
color: #667eea;
font-size: 14px;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 1px;
margin-bottom: 10px;
}
.word {
font-size: 48px;
font-weight: bold;
color: #1a1a1a;
margin-bottom: 10px;
line-height: 1.2;
}
.translation {
font-size: 24px;
color: #667eea;
margin-bottom: 20px;
font-style: italic;
}
.sentence {
font-size: 18px;
line-height: 1.6;
color: #555;
background: #f7f7f7;
padding: 20px;
border-radius: 8px;
border-left: 4px solid #667eea;
margin-bottom: 15px;
}
.meta {
display: flex;
gap: 10px;
flex-wrap: wrap;
margin-top: 20px;
}
.badge {
background: #e0e7ff;
color: #667eea;
padding: 6px 12px;
border-radius: 20px;
font-size: 12px;
font-weight: 600;
}
.explanation {
font-size: 14px;
color: #666;
margin-top: 15px;
padding: 15px;
background: #fafafa;
border-radius: 8px;
border-left: 3px solid #764ba2;
}
</style>
</head>
<body>
<div class="card">
<div class="date">{{.FormattedDate}}</div>
<div class="word">{{.Word}}</div>
<div class="translation">{{.Translation}}</div>
{{if .Sentence}}
<div class="sentence">{{.Sentence}}</div>
{{end}}
<div class="meta">
{{if .Language}}<span class="badge">{{.Language}}</span>{{end}}
{{if .Level}}<span class="badge">{{.Level}}</span>{{end}}
{{if .TopicCategory}}<span class="badge">{{.TopicCategory}}</span>{{end}}
</div>
{{if .Explanation}}
<div class="explanation">{{.Explanation}}</div>
{{end}}
</div>
</body>
</html>
`

	page, parseErr := template.New("embed").Parse(embedPage)
	if parseErr != nil {
		return fmt.Sprintf("<html><body>Error rendering template: %v</body></html>", parseErr)
	}

	// embedView carries exactly the values the template refers to.
	type embedView struct {
		FormattedDate string
		Word          string
		Translation   string
		Sentence      string
		Language      string
		Level         string
		TopicCategory string
		Explanation   string
	}
	var view embedView
	view.FormattedDate = word.Date.Format("January 2, 2006")
	view.Word = word.Word
	view.Translation = word.Translation
	view.Sentence = word.Sentence
	view.Language = word.Language
	view.Level = word.Level
	view.TopicCategory = word.TopicCategory
	view.Explanation = word.Explanation

	var out strings.Builder
	if execErr := page.Execute(&out, view); execErr != nil {
		return fmt.Sprintf("<html><body>Error executing template: %v</body></html>", execErr)
	}
	return out.String()
}
package handlers
import (
"errors"
"fmt"
"html/template"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"quizapp/internal/worker"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// WorkerAdminHandler handles worker administration endpoints.
// It bundles the services and the local worker instance that the admin
// endpoints read from.
type WorkerAdminHandler struct {
// userService is used to list users and to look them up by ID or username.
userService services.UserServiceInterface
// questionService is stored by the constructor.
questionService services.QuestionServiceInterface
// aiService supplies the AI concurrency statistics; handlers check it for nil.
aiService services.AIServiceInterface
// config is the configuration returned by GetConfigz.
config *config.Config
// worker is the local worker instance; several handlers check it for nil
// before using it.
worker *worker.Worker
// workerService provides pause state, worker status/health and notification data.
workerService services.WorkerServiceInterface
// templates is set to nil by NewWorkerAdminHandlerWithLogger.
templates *template.Template
// learningService provides the priority, weak-area and preference analytics.
learningService services.LearningServiceInterface
// dailyQuestionService is stored by the constructor.
dailyQuestionService services.DailyQuestionServiceInterface
// logger records warnings and errors raised by the handlers.
logger *observability.Logger
}
// NewWorkerAdminHandlerWithLogger assembles a WorkerAdminHandler from the
// given dependencies. The arguments are stored as given; the templates field
// is left nil.
func NewWorkerAdminHandlerWithLogger(
	userService services.UserServiceInterface,
	questionService services.QuestionServiceInterface,
	aiService services.AIServiceInterface,
	cfg *config.Config,
	worker *worker.Worker,
	workerService services.WorkerServiceInterface,
	learningService services.LearningServiceInterface,
	dailyQuestionService services.DailyQuestionServiceInterface,
	logger *observability.Logger,
) *WorkerAdminHandler {
	handler := new(WorkerAdminHandler)
	handler.userService = userService
	handler.questionService = questionService
	handler.aiService = aiService
	handler.config = cfg
	handler.worker = worker
	handler.workerService = workerService
	handler.templates = nil
	handler.learningService = learningService
	handler.dailyQuestionService = dailyQuestionService
	handler.logger = logger
	return handler
}
// GetWorkerDetails reports the local worker's status and run history together
// with the global pause flag. When no local worker is configured the status
// and history keep their zero values. A failure to read the global pause flag
// is logged as a warning and reported as "not paused".
func (h *WorkerAdminHandler) GetWorkerDetails(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_worker_details")
	defer span.End()

	var (
		status  worker.Status
		history []worker.RunRecord
	)
	if w := h.worker; w != nil {
		status = w.GetStatus()
		history = w.GetHistory()
	}

	paused, err := h.workerService.IsGlobalPaused(ctx)
	if err != nil {
		// Best effort: the endpoint still answers, with the flag defaulted.
		h.logger.Warn(ctx, "Failed to get global pause status", map[string]interface{}{"error": err.Error()})
		paused = false
	}

	c.JSON(http.StatusOK, gin.H{
		"status":        status,
		"history":       history,
		"global_paused": paused,
	})
}
// GetActivityLogs responds with the local worker's activity logs under the
// "logs" key. Without a local worker it answers with ErrServiceUnavailable.
func (h *WorkerAdminHandler) GetActivityLogs(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_activity_logs")
	defer span.End()

	w := h.worker
	if w == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}

	c.JSON(http.StatusOK, gin.H{"logs": w.GetActivityLogs()})
}
// PauseWorker sets the global pause flag through the worker service and, when
// a local worker instance exists, pauses that instance as well. If the flag
// cannot be set, the error is wrapped and reported and the local worker is
// left untouched.
func (h *WorkerAdminHandler) PauseWorker(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "pause_worker")
	defer span.End()

	err := h.workerService.SetGlobalPause(ctx, true)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to pause worker globally"))
		return
	}

	// Keep the in-process worker in step with the global flag.
	if w := h.worker; w != nil {
		w.Pause(ctx)
	}

	c.JSON(http.StatusOK, gin.H{"message": "Worker paused globally"})
}
// ResumeWorker clears the global pause flag through the worker service and,
// when a local worker instance exists, resumes that instance as well. If the
// flag cannot be cleared, the error is wrapped and reported and the local
// worker is left untouched.
func (h *WorkerAdminHandler) ResumeWorker(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "resume_worker")
	defer span.End()

	err := h.workerService.SetGlobalPause(ctx, false)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to resume worker globally"))
		return
	}

	// Keep the in-process worker in step with the global flag.
	if w := h.worker; w != nil {
		w.Resume(ctx)
	}

	c.JSON(http.StatusOK, gin.H{"message": "Worker resumed globally"})
}
// GetWorkerStatus responds with the status the worker service reports for one
// worker instance. The instance is named by the "instance" query parameter
// and defaults to "default" when the parameter is absent.
func (h *WorkerAdminHandler) GetWorkerStatus(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_worker_status")
	defer span.End()

	workerStatus, err := h.workerService.GetWorkerStatus(ctx, c.DefaultQuery("instance", "default"))
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to get worker status"))
		return
	}

	c.JSON(http.StatusOK, workerStatus)
}
// TriggerWorkerRun asks the local worker to perform a manual run. Without a
// local worker it answers with ErrServiceUnavailable.
func (h *WorkerAdminHandler) TriggerWorkerRun(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "trigger_worker_run")
	defer span.End()

	if h.worker == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}

	h.worker.TriggerManualRun()
	c.JSON(http.StatusOK, gin.H{"message": "Worker run triggered"})
}
// PauseWorkerUser marks one user as paused through the worker service.
// The JSON request body must carry a "user_id"; a body that fails binding is
// rejected as invalid input.
func (h *WorkerAdminHandler) PauseWorkerUser(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "pause_user")
	defer span.End()

	payload := struct {
		UserID int `json:"user_id" binding:"required"`
	}{}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	pauseErr := h.workerService.SetUserPause(ctx, payload.UserID, true)
	if pauseErr != nil {
		HandleAppError(c, contextutils.WrapError(pauseErr, "failed to pause user"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "User paused successfully"})
}
// ResumeWorkerUser clears the paused mark of one user through the worker
// service. The JSON request body must carry a "user_id"; a body that fails
// binding is rejected as invalid input.
func (h *WorkerAdminHandler) ResumeWorkerUser(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "resume_user")
	defer span.End()

	payload := struct {
		UserID int `json:"user_id" binding:"required"`
	}{}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	resumeErr := h.workerService.SetUserPause(ctx, payload.UserID, false)
	if resumeErr != nil {
		HandleAppError(c, contextutils.WrapError(resumeErr, "failed to resume user"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "User resumed successfully"})
}
// GetWorkerUsers returns basic user list for worker controls.
//
// The response is a JSON object whose "users" key holds one entry per user
// with "id", "username" and "is_paused". The list is always a JSON array,
// also when there are no users. If the pause state of a user cannot be read,
// a warning is logged and that user is reported as not paused, so one failing
// lookup does not fail the whole listing.
func (h *WorkerAdminHandler) GetWorkerUsers(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_worker_users")
	defer span.End()

	users, err := h.userService.GetAllUsers(ctx)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to get users"))
		return
	}

	// Non-nil and pre-sized: a nil slice would be encoded as JSON null.
	userList := make([]gin.H, 0, len(users))
	for _, user := range users {
		isPaused, pauseErr := h.workerService.IsUserPaused(ctx, user.ID)
		if pauseErr != nil {
			// Best effort per user: record the failure instead of dropping it.
			h.logger.Warn(ctx, "Failed to get user pause status", map[string]interface{}{
				"user_id": user.ID,
				"error":   pauseErr.Error(),
			})
			isPaused = false
		}
		userList = append(userList, gin.H{
			"id":        user.ID,
			"username":  user.Username,
			"is_paused": isPaused,
		})
	}

	c.JSON(http.StatusOK, gin.H{"users": userList})
}
// GetSystemHealth responds with the health information reported by the worker
// service, or with a wrapped error when it cannot be obtained.
func (h *WorkerAdminHandler) GetSystemHealth(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_system_health")
	defer span.End()

	report, healthErr := h.workerService.GetWorkerHealth(ctx)
	if healthErr != nil {
		HandleAppError(c, contextutils.WrapError(healthErr, "failed to get system health"))
		return
	}

	c.JSON(http.StatusOK, report)
}
// GetAIConcurrencyStats responds with the AI service's concurrency statistics
// under the "ai_concurrency" key. Without an AI service it answers with
// ErrAIProviderUnavailable.
func (h *WorkerAdminHandler) GetAIConcurrencyStats(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_ai_concurrency_stats")
	defer span.End()

	ai := h.aiService
	if ai == nil {
		HandleAppError(c, contextutils.ErrAIProviderUnavailable)
		return
	}

	c.JSON(http.StatusOK, gin.H{"ai_concurrency": ai.GetConcurrencyStats()})
}
// GetPriorityAnalytics responds with the priority score distribution and the
// five highest-priority questions. Each part is fetched independently: a
// failing lookup is logged and replaced by an empty default, so the endpoint
// always answers with status 200.
func (h *WorkerAdminHandler) GetPriorityAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_priority_analytics")
	defer span.End()

	scoreDist, distErr := h.learningService.GetPriorityScoreDistribution(ctx)
	if distErr != nil {
		h.logger.Error(ctx, "Error getting priority score distribution", distErr, map[string]interface{}{})
		// Zeroed buckets stand in for the missing distribution.
		scoreDist = map[string]interface{}{
			"high":    0,
			"medium":  0,
			"low":     0,
			"average": 0.0,
		}
	}

	topQuestions, topErr := h.learningService.GetHighPriorityQuestions(ctx, 5)
	if topErr != nil {
		h.logger.Error(ctx, "Error getting high priority questions", topErr, map[string]interface{}{})
		topQuestions = []map[string]interface{}{}
	}

	c.JSON(http.StatusOK, gin.H{
		"distribution":          scoreDist,
		"highPriorityQuestions": topQuestions,
	})
}
// GetUserPriorityAnalytics responds with priority analytics for the user named
// by the "userID" route parameter.
//
// A non-numeric ID yields ErrInvalidFormat; a failed or empty user lookup
// yields ErrRecordNotFound. The four analytics parts (distribution, ten
// high-priority questions, five weak areas, learning preferences) are fetched
// independently; a failing part is logged and replaced by an empty default.
func (h *WorkerAdminHandler) GetUserPriorityAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_priority_analytics")
	defer span.End()

	userID, convErr := strconv.Atoi(c.Param("userID"))
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// The user must exist before any analytics are gathered.
	user, lookupErr := h.userService.GetUserByID(ctx, userID)
	if lookupErr != nil || user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	scoreDist, distErr := h.learningService.GetUserPriorityScoreDistribution(ctx, userID)
	if distErr != nil {
		h.logger.Error(ctx, "Error getting user priority score distribution", distErr, map[string]interface{}{})
		scoreDist = map[string]interface{}{
			"high":    0,
			"medium":  0,
			"low":     0,
			"average": 0.0,
		}
	}

	topQuestions, topErr := h.learningService.GetUserHighPriorityQuestions(ctx, userID, 10)
	if topErr != nil {
		h.logger.Error(ctx, "Error getting user high priority questions", topErr, map[string]interface{}{})
		topQuestions = []map[string]interface{}{}
	}

	weakSpots, weakErr := h.learningService.GetUserWeakAreas(ctx, userID, 5)
	if weakErr != nil {
		h.logger.Error(ctx, "Error getting user weak areas", weakErr, map[string]interface{}{})
		weakSpots = []map[string]interface{}{}
	}

	prefs, prefsErr := h.learningService.GetUserLearningPreferences(ctx, userID)
	if prefsErr != nil {
		h.logger.Error(ctx, "Error getting user learning preferences", prefsErr, map[string]interface{}{})
		prefs = nil
	}

	userInfo := gin.H{
		"id":       user.ID,
		"username": user.Username,
	}
	c.JSON(http.StatusOK, gin.H{
		"user":                  userInfo,
		"distribution":          scoreDist,
		"highPriorityQuestions": topQuestions,
		"weakAreas":             weakSpots,
		"learningPreferences":   prefs,
	})
}
// GetUserPerformanceAnalytics responds with the five weakest areas by topic
// and the learning preferences usage. A failing lookup is logged and replaced
// by an empty value, so the endpoint always answers with status 200.
func (h *WorkerAdminHandler) GetUserPerformanceAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_performance_analytics")
	defer span.End()

	weakSpots, weakErr := h.learningService.GetWeakAreasByTopic(ctx, 5)
	if weakErr != nil {
		h.logger.Error(ctx, "Error getting weak areas", weakErr, map[string]interface{}{})
		weakSpots = []map[string]interface{}{}
	}

	prefsUsage, usageErr := h.learningService.GetLearningPreferencesUsage(ctx)
	if usageErr != nil {
		h.logger.Error(ctx, "Error getting learning preferences usage", usageErr, map[string]interface{}{})
		prefsUsage = map[string]interface{}{}
	}

	c.JSON(http.StatusOK, gin.H{
		"weakAreas":           weakSpots,
		"learningPreferences": prefsUsage,
	})
}
// GetGenerationIntelligence responds with the question-type gap analysis and
// the generation suggestions. A failing lookup is logged; in every case both
// values are sent as arrays, never as null.
func (h *WorkerAdminHandler) GetGenerationIntelligence(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_generation_intelligence")
	defer span.End()

	gaps, gapsErr := h.learningService.GetQuestionTypeGaps(ctx)
	if gapsErr != nil {
		h.logger.Error(ctx, "Error getting gap analysis", gapsErr, map[string]interface{}{})
		gaps = nil
	}

	suggestions, suggestErr := h.learningService.GetGenerationSuggestions(ctx)
	if suggestErr != nil {
		h.logger.Error(ctx, "Error getting generation suggestions", suggestErr, map[string]interface{}{})
		suggestions = nil
	}

	// Normalise both results once: a failed lookup and an empty result are
	// both turned into an empty array.
	if gaps == nil {
		gaps = []map[string]interface{}{}
	}
	if suggestions == nil {
		suggestions = []map[string]interface{}{}
	}

	c.JSON(http.StatusOK, gin.H{
		"gapAnalysis":           gaps,
		"generationSuggestions": suggestions,
	})
}
// GetSystemHealthAnalytics responds with the priority system's performance
// metrics and the background jobs status. A failing lookup is logged and
// replaced by an empty object, so the endpoint always answers with status 200.
func (h *WorkerAdminHandler) GetSystemHealthAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_system_health_analytics")
	defer span.End()

	perf, perfErr := h.learningService.GetPrioritySystemPerformance(ctx)
	if perfErr != nil {
		h.logger.Error(ctx, "Error getting performance metrics", perfErr, map[string]interface{}{})
		perf = map[string]interface{}{}
	}

	jobs, jobsErr := h.learningService.GetBackgroundJobsStatus(ctx)
	if jobsErr != nil {
		h.logger.Error(ctx, "Error getting background jobs status", jobsErr, map[string]interface{}{})
		jobs = map[string]interface{}{}
	}

	c.JSON(http.StatusOK, gin.H{
		"performance":    perf,
		"backgroundJobs": jobs,
	})
}
// GetUserComparisonAnalytics returns comparison analytics between users.
//
// The "user_ids" query parameter is a comma-separated list of integer user
// IDs; surrounding whitespace and empty items are ignored. A missing parameter
// or a list without any ID yields ErrMissingRequired, a non-numeric item an
// invalid-format error naming the item. For every ID that resolves to a user
// the response's "comparison" array holds the user's id and username, the
// priority score distribution and up to three weak areas. IDs whose lookup
// fails or yields no user are skipped. The array is always a JSON array, also
// when no user could be resolved.
func (h *WorkerAdminHandler) GetUserComparisonAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_comparison_analytics")
	defer span.End()

	userIDsParam := c.Query("user_ids")
	if userIDsParam == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// strings.Split with a non-empty separator always returns at least one
	// element, so only the parsed list needs an emptiness check.
	rawIDs := strings.Split(userIDsParam, ",")
	userIDs := make([]int, 0, len(rawIDs))
	for _, idStr := range rawIDs {
		idStr = strings.TrimSpace(idStr) // Remove whitespace
		if idStr == "" {
			continue
		}
		id, err := strconv.Atoi(idStr)
		if err != nil {
			HandleAppError(c, contextutils.NewAppErrorWithCause(
				contextutils.ErrorCodeInvalidFormat,
				contextutils.SeverityWarn,
				"Invalid user ID",
				idStr,
				err,
			))
			return
		}
		userIDs = append(userIDs, id)
	}
	if len(userIDs) == 0 {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Non-nil so that an empty comparison is encoded as [] rather than null.
	comparisonData := make([]gin.H, 0, len(userIDs))
	for _, userID := range userIDs {
		user, err := h.userService.GetUserByID(ctx, userID)
		// Skip invalid users. The nil check matters: a lookup that returns no
		// user and no error would otherwise be dereferenced below.
		if err != nil || user == nil {
			continue
		}

		// The analytics parts are best effort: a failure is logged and the
		// value returned by the service is passed through unchanged.
		distribution, err := h.learningService.GetUserPriorityScoreDistribution(ctx, userID)
		if err != nil {
			h.logger.Warn(ctx, "Failed to get user priority score distribution", map[string]interface{}{
				"user_id": userID,
				"error":   err.Error(),
			})
		}
		weakAreas, err := h.learningService.GetUserWeakAreas(ctx, userID, 3)
		if err != nil {
			h.logger.Warn(ctx, "Failed to get user weak areas", map[string]interface{}{
				"user_id": userID,
				"error":   err.Error(),
			})
		}

		comparisonData = append(comparisonData, gin.H{
			"user": gin.H{
				"id":       user.ID,
				"username": user.Username,
			},
			"distribution": distribution,
			"weakAreas":    weakAreas,
		})
	}

	c.JSON(http.StatusOK, gin.H{"comparison": comparisonData})
}
// GetConfigz writes the handler's configuration as indented JSON with
// status 200.
func (h *WorkerAdminHandler) GetConfigz(c *gin.Context) {
	reqCtx := c.Request.Context()
	_, span := observability.TraceHandlerFunction(reqCtx, "get_configz")
	defer span.End()

	cfg := h.config
	c.IndentedJSON(http.StatusOK, cfg)
}
// GetNotificationStats responds with the notification statistics reported by
// the worker service. On failure the error is logged and a 500 response with
// an "error" message and the error text under "details" is sent.
func (h *WorkerAdminHandler) GetNotificationStats(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_notification_stats")
	defer span.End()

	summary, statsErr := h.workerService.GetNotificationStats(ctx)
	if statsErr != nil {
		h.logger.Error(ctx, "Failed to get notification stats", statsErr, nil)
		failure := gin.H{
			"error":   "Failed to get notification statistics",
			"details": statsErr.Error(),
		}
		c.JSON(http.StatusInternalServerError, failure)
		return
	}

	c.JSON(http.StatusOK, summary)
}
// GetNotificationErrors responds with one page of notification errors.
//
// Paging comes from ParsePagination (called with 1, 20 and 100) and the
// optional filters "error_type", "notification_type" and "resolved" from
// ParseFilters. The page is written with WritePaginated under "errors",
// together with the statistics under "stats". On failure the error is logged
// and a 500 response carrying the error text is sent.
func (h *WorkerAdminHandler) GetNotificationErrors(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_notification_errors")
	defer span.End()

	page, pageSize := ParsePagination(c, 1, 20, 100)
	filters := ParseFilters(c, "error_type", "notification_type", "resolved")

	// Named notifErrors so the imported errors package is not shadowed.
	notifErrors, pagination, stats, err := h.workerService.GetNotificationErrors(
		ctx,
		page,
		pageSize,
		filters["error_type"],
		filters["notification_type"],
		filters["resolved"],
	)
	if err != nil {
		h.logger.Error(ctx, "Failed to get notification errors", err, nil)
		failure := gin.H{
			"error":   "Failed to get notification errors",
			"details": err.Error(),
		}
		c.JSON(http.StatusInternalServerError, failure)
		return
	}

	WritePaginated(c, "errors", notifErrors, pagination, gin.H{"stats": stats})
}
// GetSentNotifications responds with one page of sent notifications.
//
// Paging comes from ParsePagination (called with 1, 20 and 100) and the
// optional filters "notification_type", "status", "sent_after" and
// "sent_before" from ParseFilters. The page is written with WritePaginated
// under "notifications", together with the statistics under "stats". On
// failure the error is logged and a 500 response carrying the error text is
// sent.
func (h *WorkerAdminHandler) GetSentNotifications(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_sent_notifications")
	defer span.End()

	page, pageSize := ParsePagination(c, 1, 20, 100)
	filters := ParseFilters(c, "notification_type", "status", "sent_after", "sent_before")

	sent, pagination, stats, err := h.workerService.GetSentNotifications(
		ctx,
		page,
		pageSize,
		filters["notification_type"],
		filters["status"],
		filters["sent_after"],
		filters["sent_before"],
	)
	if err != nil {
		h.logger.Error(ctx, "Failed to get sent notifications", err, nil)
		failure := gin.H{
			"error":   "Failed to get sent notifications",
			"details": err.Error(),
		}
		c.JSON(http.StatusInternalServerError, failure)
		return
	}

	WritePaginated(c, "notifications", sent, pagination, gin.H{"stats": stats})
}
// CreateTestSentNotification stores a sent-notification record built from the
// JSON request body, for testing purposes.
//
// The body must provide user_id, notification_type, subject, template_name
// and status; error_message is optional. A body that fails binding is
// rejected as invalid input. If the worker service cannot create the record,
// the error is logged and a 500 response carrying the error text is sent.
func (h *WorkerAdminHandler) CreateTestSentNotification(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "create_test_sent_notification")
	defer span.End()

	payload := struct {
		UserID           int    `json:"user_id" binding:"required"`
		NotificationType string `json:"notification_type" binding:"required"`
		Subject          string `json:"subject" binding:"required"`
		TemplateName     string `json:"template_name" binding:"required"`
		Status           string `json:"status" binding:"required"`
		ErrorMessage     string `json:"error_message"`
	}{}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	createErr := h.workerService.CreateTestSentNotification(
		ctx,
		payload.UserID,
		payload.NotificationType,
		payload.Subject,
		payload.TemplateName,
		payload.Status,
		payload.ErrorMessage,
	)
	if createErr != nil {
		h.logger.Error(ctx, "Failed to create test sent notification", createErr, map[string]interface{}{
			"user_id":           payload.UserID,
			"notification_type": payload.NotificationType,
		})
		failure := gin.H{
			"error":   "Failed to create test sent notification",
			"details": createErr.Error(),
		}
		c.JSON(http.StatusInternalServerError, failure)
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "Test sent notification created successfully"})
}
// ForceSendNotification sends the daily reminder email to the user named in
// the request body immediately, then records that it was sent.
//
// Request body: JSON object with a required "username".
//
// The handler stops early, answering through HandleAppError or a 500 JSON
// body, when: the body cannot be bound, the user lookup fails or returns nil,
// the user has no usable email address, the learning preferences cannot be
// loaded, the preferences are nil or have daily reminders disabled, the worker
// exposes no email service, or sending the email fails.
//
// After a successful send it records the notification and updates the user's
// last-reminder timestamp; failures of those two follow-up steps are logged
// only and do not change the 200 response.
func (h *WorkerAdminHandler) ForceSendNotification(c *gin.Context) {
ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "force_send_notification")
defer span.End()
// Parse request body; "username" is the only field and it is required.
var request struct {
Username string `json:"username" binding:"required"`
}
if err := c.ShouldBindJSON(&request); err != nil {
HandleAppError(c, contextutils.NewAppErrorWithCause(
contextutils.ErrorCodeInvalidInput,
contextutils.SeverityWarn,
"Invalid request body",
"",
err,
))
return
}
// Get user by username
user, err := h.userService.GetUserByUsername(ctx, request.Username)
if err != nil {
h.logger.Error(ctx, "Failed to get user by username", err, map[string]interface{}{
"username": request.Username,
})
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to get user",
"details": err.Error(),
})
return
}
// A nil user with a nil error is treated as "not found".
if user == nil {
HandleAppError(c, contextutils.NewAppError(
contextutils.ErrorCodeRecordNotFound,
contextutils.SeverityInfo,
fmt.Sprintf("User '%s' not found", request.Username),
"",
))
return
}
// Check if user has email address (must be both valid and non-empty)
if !user.Email.Valid || user.Email.String == "" {
HandleAppError(c, contextutils.ErrMissingRequired)
return
}
// Get user's learning preferences to check daily reminder setting
prefs, err := h.learningService.GetUserLearningPreferences(ctx, user.ID)
if err != nil {
h.logger.Error(ctx, "Failed to get user learning preferences", err, map[string]interface{}{
"user_id": user.ID,
})
c.JSON(http.StatusInternalServerError, gin.H{
"error": "Failed to get user preferences",
"details": err.Error(),
})
return
}
// Check if daily reminders are enabled for this user; missing preferences
// are handled the same way as a disabled reminder.
if prefs == nil || !prefs.DailyReminderEnabled {
HandleAppError(c, contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn, "User has daily reminders disabled", ""))
return
}
// Force send the daily reminder (bypassing time and date checks).
// These values are what gets recorded and echoed in the response; the email
// itself is produced by emailService.SendDailyReminder.
subject := "Time for your daily quiz! ð"
status := "sent"
errorMsg := ""
// Get email service from worker
emailService := h.worker.GetEmailService()
if emailService == nil {
HandleAppError(c, contextutils.ErrServiceUnavailable)
return
}
// Send the email; a failure here aborts the request.
if err := emailService.SendDailyReminder(ctx, user); err != nil {
h.logger.Error(ctx, "Failed to send forced daily reminder", err, map[string]interface{}{
"user_id": user.ID,
"email": user.Email.String,
})
HandleAppError(c, contextutils.WrapError(err, "failed to send notification"))
return
}
// Record the sent notification in the database
if err := emailService.RecordSentNotification(ctx, user.ID, "daily_reminder", subject, "daily_reminder", status, errorMsg); err != nil {
h.logger.Error(ctx, "Failed to record sent notification", err, map[string]interface{}{
"user_id": user.ID,
})
// Don't fail the request if recording fails
}
// Update the last reminder sent timestamp for this user
if err := h.learningService.UpdateLastDailyReminderSent(ctx, user.ID); err != nil {
h.logger.Error(ctx, "Failed to update last daily reminder sent timestamp", err, map[string]interface{}{
"user_id": user.ID,
})
// Don't fail the request if timestamp update fails
}
h.logger.Info(ctx, "Forced notification sent successfully", map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"email": user.Email.String,
})
// Echo back who was notified and what was recorded.
c.JSON(http.StatusOK, gin.H{
"message": "Notification sent successfully",
"user": gin.H{
"id": user.ID,
"username": user.Username,
"email": user.Email.String,
},
"notification": gin.H{
"type": "daily_reminder",
"subject": subject,
"status": status,
},
})
}
// GetUserDailyQuestions returns the daily question assignments of one user for
// one calendar date (admin only).
//
// Route parameters:
//   - userId: integer user ID; a non-integer value yields ErrInvalidFormat.
//   - date: day in YYYY-MM-DD form; an empty value yields ErrMissingRequired
//     and an unparsable one yields ErrInvalidFormat.
//
// The user is looked up before the date is validated; a lookup error is
// wrapped and passed to HandleAppError, and a nil user yields
// ErrRecordNotFound. A failure of the daily question service is logged and
// answered with a 500 whose "details" field carries the error text. On success
// the response is 200 with {"questions": [...]}, one entry per assignment,
// each embedding the full question plus per-user counters.
func (h *WorkerAdminHandler) GetUserDailyQuestions(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "admin_get_user_daily_questions")
	defer span.End()

	// Parse user ID
	userIDStr := c.Param("userId")
	userID, err := strconv.Atoi(userIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Check if user exists
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for daily questions", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	// Parse date
	dateStr := c.Param("date")
	if dateStr == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	date, err := time.Parse("2006-01-02", dateStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Add span attributes for observability
	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", dateStr),
	)

	// Get daily questions for the user and date
	questions, err := h.dailyQuestionService.GetDailyQuestions(ctx, userID, date)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user daily questions", err, map[string]interface{}{
			"user_id": userID,
			"date":    dateStr,
		})
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to get daily questions",
			"details": err.Error(),
		})
		return
	}

	// Convert to API format (similar to the daily question handler)
	apiQuestions := make([]gin.H, len(questions))
	for i, q := range questions {
		// completed_at is serialised as null unless the assignment has a
		// valid completion time.
		var completedAt *time.Time
		if q.CompletedAt.Valid {
			// Copy the timestamp into a per-iteration local instead of taking
			// the address of a field reached through the range variable. The
			// pointer is only dereferenced when the response is serialised
			// after the loop, so under pre-Go 1.22 loop semantics a pointer
			// into a by-value range variable would make every entry report
			// the last element's time.
			completedTime := q.CompletedAt.Time
			completedAt = &completedTime
		}
		apiQuestions[i] = gin.H{
			"id":              q.ID,
			"user_id":         q.UserID,
			"question_id":     q.QuestionID,
			"assignment_date": q.AssignmentDate,
			"is_completed":    q.IsCompleted,
			"completed_at":    completedAt,
			"created_at":      q.CreatedAt,
			// Per-user stats for admin UI
			"user_shown_count":     q.DailyShownCount,
			"user_total_responses": q.UserTotalResponses,
			"user_correct_count":   q.UserCorrectCount,
			"user_incorrect_count": q.UserIncorrectCount,
			"question": gin.H{
				"id":                  q.Question.ID,
				"type":                q.Question.Type,
				"language":            q.Question.Language,
				"level":               q.Question.Level,
				"difficulty_score":    q.Question.DifficultyScore,
				"content":             q.Question.Content,
				"correct_answer":      q.Question.CorrectAnswer,
				"explanation":         q.Question.Explanation,
				"created_at":          q.Question.CreatedAt,
				"status":              q.Question.Status,
				"topic_category":      q.Question.TopicCategory,
				"grammar_focus":       q.Question.GrammarFocus,
				"vocabulary_domain":   q.Question.VocabularyDomain,
				"scenario":            q.Question.Scenario,
				"style_modifier":      q.Question.StyleModifier,
				"difficulty_modifier": q.Question.DifficultyModifier,
				"time_context":        q.Question.TimeContext,
			},
		}
	}

	c.JSON(http.StatusOK, gin.H{"questions": apiQuestions})
}
// RegenerateUserDailyQuestions asks the daily question service to regenerate
// the assignments of one user for one calendar date (admin only).
//
// Route parameters are the same as for GetUserDailyQuestions: an integer
// "userId" and a "date" in YYYY-MM-DD form. Validation failures go through
// HandleAppError (ErrInvalidFormat, ErrMissingRequired, ErrRecordNotFound).
//
// When regeneration fails with a *services.NoQuestionsAvailableError the
// response is a 400 that includes the candidate details carried by that
// error; any other failure is a 500 with the error text. Success is a 200
// with a confirmation message.
func (h *WorkerAdminHandler) RegenerateUserDailyQuestions(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "admin_regenerate_user_daily_questions")
	defer span.End()

	// Route parameter: numeric user ID.
	userID, err := strconv.Atoi(c.Param("userId"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// The target user must exist.
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for daily questions regeneration", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	// Route parameter: assignment date.
	dateStr := c.Param("date")
	if dateStr == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	date, err := time.Parse("2006-01-02", dateStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", dateStr),
	)

	// The handler refuses to run without a worker service even though the
	// regeneration call below goes through the daily question service.
	if h.workerService == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}

	logFields := map[string]interface{}{
		"user_id": userID,
		"date":    dateStr,
	}

	if regenErr := h.dailyQuestionService.RegenerateDailyQuestions(ctx, userID, date); regenErr != nil {
		h.logger.Error(ctx, "Failed to regenerate daily questions", regenErr, logFields)

		// Prefer the structured error from the service when no questions were
		// available for assignment.
		var noQuestions *services.NoQuestionsAvailableError
		if !errors.As(regenErr, &noQuestions) {
			c.JSON(http.StatusInternalServerError, gin.H{
				"error":   "Failed to regenerate daily questions",
				"details": regenErr.Error(),
			})
			return
		}
		c.JSON(http.StatusBadRequest, gin.H{
			"error":                    "Failed to regenerate daily questions",
			"details":                  regenErr.Error(),
			"user":                     gin.H{"id": user.ID, "username": user.Username, "language": noQuestions.Language, "level": noQuestions.Level},
			"candidate_count":          noQuestions.CandidateCount,
			"candidate_ids":            noQuestions.CandidateIDs,
			"total_matching_questions": noQuestions.TotalMatching,
		})
		return
	}

	h.logger.Info(ctx, "Daily questions regenerated successfully", logFields)
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "Daily questions regenerated successfully. All existing assignments have been cleared and new questions assigned."})
}
// Package middleware provides authentication and authorization middleware for the Gin web framework.
package middleware
import (
"context"
"net/http"
"strings"
"quizapp/internal/models"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// Session keys for storing user information.
// The middlewares in this file use the same names both to read values from
// the session and to publish values on the gin context via c.Set.
const (
// UserIDKey is the key used to store user ID in session
UserIDKey = "user_id"
// UsernameKey is the key used to store username in session
UsernameKey = "username"
// AuthMethodKey is the key used to store authentication method
// (one of the AuthMethod* constants)
AuthMethodKey = "auth_method"
// APIKeyIDKey is the key used to store API key ID (for API key auth)
APIKeyIDKey = "api_key_id"
)
// AuthMethod constants are the values stored under AuthMethodKey.
const (
// AuthMethodSession marks a request authenticated through the session.
AuthMethodSession = "session"
// AuthMethodAPIKey marks a request authenticated through an API key.
AuthMethodAPIKey = "api_key"
)
// AuthAPIKeyValidator is an interface for validating API keys.
// RequireAuthWithAPIKey is its consumer in this file.
type AuthAPIKeyValidator interface {
// ValidateAPIKey resolves a raw key to its stored record. The middleware
// treats a non-nil error or a nil record as an invalid key.
ValidateAPIKey(ctx context.Context, rawKey string) (*models.AuthAPIKey, error)
// UpdateLastUsed records that the key with the given ID was just used.
// The middleware calls it in a separate goroutine and ignores the result.
UpdateLastUsed(ctx context.Context, keyID int) error
}
// AuthUserServiceGetter is an interface for getting user info.
// RequireAuthWithAPIKey uses it to resolve the owner of an API key.
type AuthUserServiceGetter interface {
// GetUserByID returns the user with the given ID. The middleware treats a
// non-nil error or a nil user as "user not found".
GetUserByID(ctx context.Context, userID int) (*models.User, error)
}
// RequireAuth returns a middleware that only lets session-authenticated
// requests through. API keys are not considered here; see
// RequireAuthWithAPIKey for the variant that accepts both.
//
// The session must hold a user ID (an int, or a float64 that is converted to
// int) and a non-empty string username. Anything else is answered with a 401
// "Authentication required" body and the chain is aborted. On success the
// user ID, username and AuthMethodSession are stored on the gin context.
func RequireAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		// reject sends the single 401 response this middleware produces.
		reject := func() {
			c.JSON(http.StatusUnauthorized, gin.H{
				"error": "Authentication required",
				"code":  "UNAUTHORIZED",
			})
			c.Abort()
		}

		store := sessions.Default(c)

		// The stored ID may come back as float64 (JSON numbers are often
		// stored that way); a missing value or any other type is rejected.
		var id int
		switch raw := store.Get(UserIDKey).(type) {
		case int:
			id = raw
		case float64:
			id = int(raw)
		default:
			reject()
			return
		}

		// The username must be present, a string, and non-empty.
		name, isString := store.Get(UsernameKey).(string)
		if !isString || name == "" {
			reject()
			return
		}

		// Publish the identity for downstream handlers.
		c.Set(UserIDKey, id)
		c.Set(UsernameKey, name)
		c.Set(AuthMethodKey, AuthMethodSession)
		c.Next()
	}
}
// RequireAuthWithAPIKey returns a middleware that authenticates a request by
// API key when one is supplied and by session otherwise.
//
// The key is taken from an "Authorization: Bearer <key>" header or, when that
// header is absent or not a Bearer header, from the "api_key" query
// parameter. Once a non-empty key is found the session is never consulted:
//   - a key that fails validation yields 401 "Invalid API key";
//   - a key not permitted for the request method yields 403;
//   - a key whose owner cannot be loaded yields 401.
//
// On API key success the user ID, username, AuthMethodAPIKey and the key ID
// are stored on the gin context and the key's last-used time is updated in a
// background goroutine. Without a key the session rules of RequireAuth apply.
func RequireAuthWithAPIKey(apiKeyService AuthAPIKeyValidator, userService AuthUserServiceGetter) gin.HandlerFunc {
	return func(c *gin.Context) {
		// deny writes an error body with the given status and aborts the chain.
		deny := func(status int, message, code string) {
			c.JSON(status, gin.H{
				"error": message,
				"code":  code,
			})
			c.Abort()
		}

		// Locate the key: Bearer header first, query parameter otherwise.
		var rawKey string
		if header := c.GetHeader("Authorization"); header != "" && strings.HasPrefix(header, "Bearer ") {
			rawKey = strings.TrimPrefix(header, "Bearer ")
		} else {
			rawKey = c.Query("api_key")
		}

		if rawKey != "" {
			reqCtx := c.Request.Context()

			apiKey, validateErr := apiKeyService.ValidateAPIKey(reqCtx, rawKey)
			if validateErr != nil || apiKey == nil {
				// A supplied key that does not validate is never retried via
				// the session.
				deny(http.StatusUnauthorized, "Invalid API key", "UNAUTHORIZED")
				return
			}

			// Check permission level against request method.
			if !apiKey.CanPerformMethod(c.Request.Method) {
				deny(http.StatusForbidden, "This API key does not have permission for this operation", "FORBIDDEN")
				return
			}

			// The owner is needed to put the username on the context.
			owner, lookupErr := userService.GetUserByID(reqCtx, apiKey.UserID)
			if lookupErr != nil || owner == nil {
				deny(http.StatusUnauthorized, "Invalid API key - user not found", "UNAUTHORIZED")
				return
			}

			c.Set(UserIDKey, apiKey.UserID)
			c.Set(UsernameKey, owner.Username)
			c.Set(AuthMethodKey, AuthMethodAPIKey)
			c.Set(APIKeyIDKey, apiKey.ID)

			// Update the last-used timestamp off the request path; the result
			// is deliberately ignored so it cannot fail the request.
			go func(keyID int) {
				_ = apiKeyService.UpdateLastUsed(context.Background(), keyID)
			}(apiKey.ID)

			c.Next()
			return
		}

		// No key supplied: fall back to session authentication.
		store := sessions.Default(c)

		var id int
		switch raw := store.Get(UserIDKey).(type) {
		case int:
			id = raw
		case float64:
			// JSON numbers are often stored as float64.
			id = int(raw)
		default:
			deny(http.StatusUnauthorized, "Authentication required", "UNAUTHORIZED")
			return
		}

		name, isString := store.Get(UsernameKey).(string)
		if !isString || name == "" {
			deny(http.StatusUnauthorized, "Authentication required", "UNAUTHORIZED")
			return
		}

		c.Set(UserIDKey, id)
		c.Set(UsernameKey, name)
		c.Set(AuthMethodKey, AuthMethodSession)
		c.Next()
	}
}
// RequireAdmin returns a middleware that requires a session-authenticated
// user who is also an admin.
//
// userService must provide
//
//	IsAdmin(ctx context.Context, userID int) (bool, error)
//
// and RequireAdmin panics at construction time when it does not.
//
// Per request the middleware applies the same session checks as RequireAuth
// (401 on any failure), then calls IsAdmin: an error yields 500 and a false
// result yields 403. On success the user ID, username and AuthMethodSession
// are stored on the gin context, matching what RequireAuth publishes, so
// handlers behind either middleware can read the same keys.
func RequireAdmin(userService interface{}) gin.HandlerFunc {
	// Type assertion to get the user service
	us, ok := userService.(interface {
		IsAdmin(ctx context.Context, userID int) (bool, error)
	})
	if !ok {
		panic("userService must implement IsAdmin method")
	}

	return func(c *gin.Context) {
		// unauthenticated sends the 401 used for every session failure.
		unauthenticated := func() {
			c.JSON(http.StatusUnauthorized, gin.H{
				"error": "Authentication required",
				"code":  "UNAUTHORIZED",
			})
			c.Abort()
		}

		// First check authentication
		session := sessions.Default(c)

		// Validate user_id is an integer; JSON numbers are often stored as
		// float64, so that form is converted.
		var userIDInt int
		switch v := session.Get(UserIDKey).(type) {
		case int:
			userIDInt = v
		case float64:
			userIDInt = int(v)
		default:
			unauthenticated()
			return
		}

		// Validate username is a string and not empty
		usernameStr, isString := session.Get(UsernameKey).(string)
		if !isString || usernameStr == "" {
			unauthenticated()
			return
		}

		// Check if user has admin role
		isAdmin, err := us.IsAdmin(c.Request.Context(), userIDInt)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{
				"error": "Failed to check admin status",
				"code":  "INTERNAL_ERROR",
			})
			c.Abort()
			return
		}
		if !isAdmin {
			c.JSON(http.StatusForbidden, gin.H{
				"error": "Admin access required",
				"code":  "FORBIDDEN",
			})
			c.Abort()
			return
		}

		// Store user info in context for handlers to use. The auth method is
		// published as well so admin routes expose the same context keys as
		// routes guarded by RequireAuth.
		c.Set(UserIDKey, userIDInt)
		c.Set(UsernameKey, usernameStr)
		c.Set(AuthMethodKey, AuthMethodSession)
		c.Next()
	}
}
package middleware
import (
	"fmt"
	"net/http"
	"runtime/debug"
	"sync"
	"time"

	contextutils "quizapp/internal/utils"

	"github.com/gin-gonic/gin"
)
// ErrorRecoveryConfig holds the tunables for the error recovery middleware:
// the retry/backoff schedule and the optional circuit breaker.
type ErrorRecoveryConfig struct {
	// MaxRetries is the maximum number of retries for retryable errors.
	MaxRetries int
	// RetryDelay is the base delay between retries.
	RetryDelay time.Duration
	// MaxRetryDelay is the upper bound for the delay between retries.
	MaxRetryDelay time.Duration
	// EnableCircuitBreaker turns the circuit breaker pattern on.
	EnableCircuitBreaker bool
	// CircuitBreakerThreshold is the failure count at which the circuit opens.
	CircuitBreakerThreshold int
	// CircuitBreakerTimeout is how long to wait before retrying after the
	// circuit opens.
	CircuitBreakerTimeout time.Duration
}

// DefaultErrorRecoveryConfig returns a freshly allocated configuration with
// the defaults: 3 retries, 100ms base delay capped at 5s, and a disabled
// circuit breaker preset to a threshold of 5 failures and a 30s timeout.
func DefaultErrorRecoveryConfig() *ErrorRecoveryConfig {
	cfg := new(ErrorRecoveryConfig)

	// Retry schedule.
	cfg.MaxRetries = 3
	cfg.RetryDelay = 100 * time.Millisecond
	cfg.MaxRetryDelay = 5 * time.Second

	// Circuit breaker: off by default, but preconfigured.
	cfg.EnableCircuitBreaker = false
	cfg.CircuitBreakerThreshold = 5
	cfg.CircuitBreakerTimeout = 30 * time.Second

	return cfg
}
// circuitBreakerState represents the state of a circuit breaker
type circuitBreakerState int
const (
// circuitClosed is the zero value: requests are allowed.
circuitClosed circuitBreakerState = iota
// circuitOpen rejects requests until CircuitBreakerTimeout has elapsed
// since the last recorded failure.
circuitOpen
// circuitHalfOpen allows requests again after the open timeout.
circuitHalfOpen
)
// circuitBreaker tracks failures and manages circuit state.
//
// ErrorRecoveryMiddleware creates one breaker and shares it between every
// request that passes through the returned handler, so the mutable fields are
// guarded by mu.
type circuitBreaker struct {
	mu          sync.Mutex // guards state, failures and lastFailure
	state       circuitBreakerState
	failures    int
	lastFailure time.Time
	config      *ErrorRecoveryConfig // not modified by the breaker
}

// newCircuitBreaker creates a new circuit breaker in the closed state.
func newCircuitBreaker(config *ErrorRecoveryConfig) *circuitBreaker {
	return &circuitBreaker{
		state:  circuitClosed,
		config: config,
	}
}

// canExecute reports whether the circuit breaker allows execution.
//
// A closed or half-open circuit always allows it. An open circuit allows it
// only once more than CircuitBreakerTimeout has passed since the last
// failure, and moves to half-open when it does.
func (cb *circuitBreaker) canExecute() bool {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	switch cb.state {
	case circuitClosed:
		return true
	case circuitOpen:
		if time.Since(cb.lastFailure) > cb.config.CircuitBreakerTimeout {
			cb.state = circuitHalfOpen
			return true
		}
		return false
	case circuitHalfOpen:
		return true
	default:
		return false
	}
}

// recordSuccess records a successful execution: the failure counter is reset
// and the circuit is closed.
func (cb *circuitBreaker) recordSuccess() {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	cb.resetLocked()
}

// recordFailure records a failed execution and opens the circuit once the
// failure counter reaches CircuitBreakerThreshold.
func (cb *circuitBreaker) recordFailure() {
	cb.mu.Lock()
	defer cb.mu.Unlock()
	cb.failLocked()
}

// recordStatus applies the outcome of a finished request in one critical
// section: a 5xx status counts as a failure; any other status counts as a
// success only while the circuit is half-open and is ignored otherwise.
func (cb *circuitBreaker) recordStatus(status int) {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	if status >= 500 {
		cb.failLocked()
		return
	}
	if cb.state == circuitHalfOpen {
		cb.resetLocked()
	}
}

// resetLocked clears the failure counter and closes the circuit.
// The caller must hold cb.mu.
func (cb *circuitBreaker) resetLocked() {
	cb.failures = 0
	cb.state = circuitClosed
}

// failLocked counts one failure, stamps its time and opens the circuit when
// the threshold is reached. The caller must hold cb.mu.
func (cb *circuitBreaker) failLocked() {
	cb.failures++
	cb.lastFailure = time.Now()
	if cb.failures >= cb.config.CircuitBreakerThreshold {
		cb.state = circuitOpen
	}
}

// ErrorRecoveryMiddleware creates middleware for handling panics and retrying
// failed requests.
//
// A nil config is replaced by DefaultErrorRecoveryConfig. When
// config.EnableCircuitBreaker is set, a single circuit breaker is created and
// shared by all requests served by the returned handler.
//
// Per request the handler:
//  1. recovers from panics further down the chain, prints the panic and its
//     stack trace, and answers with an internal-error AppError (the stack
//     trace is added to the details in gin debug mode);
//  2. answers 503 without running the chain while the breaker is open;
//  3. runs the chain, then feeds the response status to the breaker;
//  4. hands the request to retryWithBackoff when shouldRetry says so.
//
// A panic skips steps 3 and 4, so panics are not counted by the breaker.
func ErrorRecoveryMiddleware(logger interface{}, config *ErrorRecoveryConfig) gin.HandlerFunc {
	if config == nil {
		config = DefaultErrorRecoveryConfig()
	}

	// Create circuit breaker if enabled
	var cb *circuitBreaker
	if config.EnableCircuitBreaker {
		cb = newCircuitBreaker(config)
	}

	return func(c *gin.Context) {
		defer func() {
			if err := recover(); err != nil {
				// Log the panic with stack trace
				stackTrace := string(debug.Stack())
				fmt.Printf("Panic recovered: %v\nStack trace: %s\n", err, stackTrace)

				// Convert panic value to error if needed
				var panicErr error
				if e, ok := err.(error); ok {
					panicErr = e
				} else {
					panicErr = contextutils.WrapErrorf(nil, "panic: %v", err)
				}

				// Send error response
				appErr := contextutils.NewAppErrorWithCause(
					contextutils.ErrorCodeInternalError,
					contextutils.SeverityFatal,
					"Internal server error",
					"A panic occurred while processing the request",
					contextutils.WrapError(panicErr, "panic"),
				)

				// Add stack trace to error details in development
				if gin.Mode() == gin.DebugMode {
					appErr.Details = fmt.Sprintf("%s\nStack trace: %s", appErr.Details, stackTrace)
				}

				HandleAppError(c, appErr)
				c.Abort()
			}
		}()

		// Check circuit breaker
		if cb != nil && !cb.canExecute() {
			ServiceUnavailable(c, "Service temporarily unavailable due to high error rate")
			c.Abort()
			return
		}

		// Process request
		c.Next()

		// Record success/failure for circuit breaker. The half-open check
		// happens inside recordStatus, under the breaker's lock, instead of
		// reading cb.state from this goroutine.
		if cb != nil {
			cb.recordStatus(c.Writer.Status())
		}

		// Retry logic for failed requests
		if shouldRetry(c.Writer.Status(), c.Errors) {
			retryWithBackoff(c, config, logger)
		}
	}
}
// shouldRetry reports whether a finished request qualifies for a retry.
//
// It returns true for any 5xx status, for 408 (Request Timeout) and 429 (Too
// Many Requests), and otherwise when contextutils.IsRetryable accepts at
// least one of the errors collected on the request.
func shouldRetry(statusCode int, errors []*gin.Error) bool {
	switch {
	case statusCode >= 500,
		// 4xx statuses that might be transient.
		statusCode == http.StatusRequestTimeout,
		statusCode == http.StatusTooManyRequests:
		return true
	}

	// The status alone does not ask for a retry; look at the errors.
	for i := range errors {
		if contextutils.IsRetryable(errors[i]) {
			return true
		}
	}
	return false
}
// retryWithBackoff walks the exponential backoff schedule for a failed
// request.
//
// Only idempotent methods (GET, HEAD, OPTIONS, PUT, DELETE) with a non-empty
// handler name are considered; anything else returns immediately. The delay
// starts at config.RetryDelay, doubles after every attempt and is capped at
// config.MaxRetryDelay, for at most config.MaxRetries attempts. When logger
// is non-nil each attempt is printed.
//
// Each wait ends early, and the function returns, as soon as the request's
// context is done, so a cancelled request does not keep its goroutine
// sleeping through the rest of the schedule.
//
// Note: the request is not actually re-executed here; see the comment inside
// the loop.
func retryWithBackoff(c *gin.Context, config *ErrorRecoveryConfig, logger interface{}) {
	// Only retry idempotent methods (GET, HEAD, OPTIONS, PUT, DELETE)
	method := c.Request.Method
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPut, http.MethodDelete:
	default:
		return
	}

	// Get the original handler
	if c.HandlerName() == "" {
		return
	}

	ctx := c.Request.Context()

	// Calculate retry delay with exponential backoff
	delay := config.RetryDelay
	for i := 0; i < config.MaxRetries; i++ {
		// Wait for the backoff delay, but give up when the request is
		// cancelled instead of sleeping unconditionally.
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}

		// Double the delay for next iteration (with max limit)
		delay *= 2
		if delay > config.MaxRetryDelay {
			delay = config.MaxRetryDelay
		}

		// Log retry attempt
		if logger != nil {
			// This would be logged using the observability logger in real implementation
			fmt.Printf("Retrying request %s %s (attempt %d/%d)\n",
				method, c.Request.URL.Path, i+1, config.MaxRetries)
		}

		// Note: In a real implementation, we would need to recreate the request
		// and re-execute it. This is a simplified version for demonstration.
		// The actual retry logic would depend on the specific use case.
	}
}
// HandleAppError writes the HTTP error response for err.
//
// A non-nil *contextutils.AppError is rendered through StandardizeAppError,
// which derives the status from the error's code. Every other error is
// rendered as a generic 500 "Internal server error" whose details are the
// error's text.
//
// A nil error, or an interface holding a nil *AppError, is answered with the
// same generic 500 with empty details instead of dereferencing nil.
func HandleAppError(c *gin.Context, err error) {
	details := ""
	if appErr, ok := err.(*contextutils.AppError); ok {
		if appErr != nil {
			StandardizeAppError(c, appErr)
			return
		}
		// Typed nil *AppError: nothing to render, fall through to the
		// generic response.
	} else if err != nil {
		// Fallback for non-AppError types
		details = err.Error()
	}
	StandardizeHTTPError(c, http.StatusInternalServerError, "Internal server error", details)
}
// StandardizeAppError writes err as a JSON error response.
//
// The body is err.ToJSON() extended with a "retryable" flag taken from
// contextutils.IsRetryable; the HTTP status is derived from err.Code via
// mapErrorCodeToHTTPStatus.
func StandardizeAppError(c *gin.Context, err *contextutils.AppError) {
	status := mapErrorCodeToHTTPStatus(err.Code)

	body := err.ToJSON()
	// Tell clients whether repeating the request can help.
	body["retryable"] = contextutils.IsRetryable(err)

	c.JSON(status, body)
}
// StandardizeHTTPError writes an error response in the same structured format
// as StandardizeAppError, built from a plain message and details string.
//
// The status code argument is ignored: the response always carries
// ErrorCodeInternalError with SeverityError, and the HTTP status follows from
// that code.
func StandardizeHTTPError(c *gin.Context, _ int, message, details string) {
	StandardizeAppError(c, contextutils.NewAppError(
		contextutils.ErrorCodeInternalError,
		contextutils.SeverityError,
		message,
		details,
	))
}
// ServiceUnavailable answers the request with a standardized
// ErrorCodeServiceUnavailable payload (mapped to HTTP 503) carrying msg as
// the message and no details.
func ServiceUnavailable(c *gin.Context, msg string) {
	StandardizeAppError(c, contextutils.NewAppError(
		contextutils.ErrorCodeServiceUnavailable,
		contextutils.SeverityError,
		msg,
		"",
	))
}
// mapErrorCodeToHTTPStatus translates an AppError code into the HTTP status
// used for the response.
//
// Codes that are not listed below, including ErrorCodeInternalError, the
// database query/transaction codes, the AI request/response/config codes and
// ErrorCodeOAuthProviderError, map to 500 Internal Server Error.
func mapErrorCodeToHTTPStatus(code contextutils.ErrorCode) int {
	switch code {
	// 400: malformed or invalid client input.
	case contextutils.ErrorCodeInvalidInput,
		contextutils.ErrorCodeMissingRequired,
		contextutils.ErrorCodeInvalidFormat,
		contextutils.ErrorCodeValidationFailed,
		contextutils.ErrorCodeOAuthStateMismatch:
		return http.StatusBadRequest

	// 401: missing, expired or wrong credentials.
	case contextutils.ErrorCodeUnauthorized,
		contextutils.ErrorCodeSessionExpired,
		contextutils.ErrorCodeInvalidCredentials:
		return http.StatusUnauthorized

	// 403
	case contextutils.ErrorCodeForbidden:
		return http.StatusForbidden

	// 404: the requested entity does not exist.
	case contextutils.ErrorCodeRecordNotFound,
		contextutils.ErrorCodeQuestionNotFound,
		contextutils.ErrorCodeAssignmentNotFound:
		return http.StatusNotFound

	// 409
	case contextutils.ErrorCodeRecordExists:
		return http.StatusConflict

	// 429
	case contextutils.ErrorCodeRateLimit:
		return http.StatusTooManyRequests

	// 408: timeouts are reported with the Request Timeout status.
	case contextutils.ErrorCodeTimeout:
		return http.StatusRequestTimeout

	// 503: a dependency is unreachable.
	case contextutils.ErrorCodeServiceUnavailable,
		contextutils.ErrorCodeDatabaseConnection,
		contextutils.ErrorCodeAIProviderUnavailable:
		return http.StatusServiceUnavailable
	}

	// Everything else, known or unknown, is an internal server error.
	return http.StatusInternalServerError
}
package middleware
import (
"encoding/json"
"fmt"
"os"
"strings"
contextutils "quizapp/internal/utils"
"github.com/xeipuuv/gojsonschema"
"gopkg.in/yaml.v2"
)
// SchemaLoader loads JSON schemas from the Swagger specification
// and validates data against them.
type SchemaLoader struct {
// schemas holds the compiled schema for each component name.
schemas map[string]*gojsonschema.Schema
// jsonCompatibleSchemas holds the converted (string-keyed) form of every
// component schema, kept for building further schemas later.
jsonCompatibleSchemas map[string]interface{}
// swaggerData is the parsed specification; nil until a spec has been loaded.
swaggerData map[string]interface{}
}
// NewSchemaLoader returns an empty SchemaLoader whose schema maps are
// initialised and whose swagger data is still unset.
func NewSchemaLoader() *SchemaLoader {
	loader := new(SchemaLoader)
	loader.schemas = make(map[string]*gojsonschema.Schema)
	loader.jsonCompatibleSchemas = make(map[string]interface{})
	return loader
}
// LoadSchemasFromSwagger reads the Swagger specification at swaggerPath and
// loads every schema it defines.
//
// A read failure is returned wrapped with "failed to read swagger file";
// otherwise the result of LoadSchemasFromSwaggerFromData is returned.
func (sl *SchemaLoader) LoadSchemasFromSwagger(swaggerPath string) error {
	contents, readErr := os.ReadFile(swaggerPath)
	if readErr == nil {
		return sl.LoadSchemasFromSwaggerFromData(contents)
	}
	return contextutils.WrapError(readErr, "failed to read swagger file")
}
// LoadSchemasFromSwaggerFromData loads all schemas from swagger data bytes.
//
// data is parsed as YAML. The parsed document is stored on the loader, every
// entry under components/schemas is converted with convertToJSONCompatible,
// and each one is compiled into a gojsonschema.Schema stored under its name.
//
// An error is returned when the YAML cannot be parsed or when the
// components or components/schemas section is missing or not a mapping.
// Schemas that fail to marshal or compile are reported with a printed warning
// and skipped; they do not make the call fail.
func (sl *SchemaLoader) LoadSchemasFromSwaggerFromData(data []byte) error {
// Parse the Swagger spec (YAML only)
var swagger map[string]interface{}
if err := yaml.Unmarshal(data, &swagger); err != nil {
return contextutils.WrapError(err, "failed to parse swagger file as YAML")
}
fmt.Printf("â Successfully parsed swagger file as YAML\n")
// Store the parsed swagger data for later use
sl.swaggerData = swagger
// Extract components/schemas. Nested mappings are expected as
// map[interface{}]interface{}, which is what the assertions below require.
components, ok := swagger["components"].(map[interface{}]interface{})
if !ok {
fmt.Printf("â No components section found. Available keys: %v\n", getKeys(swagger))
fmt.Printf("â Components type: %T, value: %v\n", swagger["components"], swagger["components"])
return contextutils.ErrorWithContextf("no components section found in swagger")
}
schemas, ok := components["schemas"].(map[interface{}]interface{})
if !ok {
fmt.Printf("â No schemas section found in components. Available keys: %v\n", getKeysInterface(components))
fmt.Printf("â Schemas type: %T, value: %v\n", components["schemas"], components["schemas"])
return contextutils.ErrorWithContextf("no schemas section found in swagger")
}
// Convert schemas to JSON-compatible format; entries whose name is not a
// string are skipped with a warning.
jsonCompatibleSchemas := make(map[string]interface{})
for schemaName, schemaData := range schemas {
schemaNameStr, ok := schemaName.(string)
if !ok {
fmt.Printf("Warning: schema name is not a string: %v\n", schemaName)
continue
}
convertedSchema := convertToJSONCompatible(schemaData)
jsonCompatibleSchemas[schemaNameStr] = convertedSchema
}
// Store jsonCompatibleSchemas for creating array schemas later
sl.jsonCompatibleSchemas = jsonCompatibleSchemas
// Load each schema
for schemaNameStr := range jsonCompatibleSchemas {
// Create a schema document with the full swagger context for $ref resolution:
// every component schema is embedded and the root simply references the
// one being loaded.
completeSchemaDoc := map[string]interface{}{
"$schema": "http://json-schema.org/draft-07/schema#",
"components": map[string]interface{}{
"schemas": jsonCompatibleSchemas,
},
"$ref": "#/components/schemas/" + schemaNameStr,
}
schemaBytes, err := json.Marshal(completeSchemaDoc)
if err != nil {
fmt.Printf("Warning: failed to marshal schema %s: %v\n", schemaNameStr, err)
continue
}
// Load the schema
schemaLoader := gojsonschema.NewBytesLoader(schemaBytes)
schema, err := gojsonschema.NewSchema(schemaLoader)
if err != nil {
fmt.Printf("Warning: failed to load schema %s: %v\n", schemaNameStr, err)
continue
}
sl.schemas[schemaNameStr] = schema
fmt.Printf("â Loaded schema: %s\n", schemaNameStr)
}
return nil
}
// getKeys returns the keys of m in map iteration order, which is unspecified.
// The result is never nil, even for an empty or nil map.
func getKeys(m map[string]interface{}) []string {
	names := make([]string, len(m))
	next := 0
	for name := range m {
		names[next] = name
		next++
	}
	return names
}
// getKeysInterface returns the string keys of m in map iteration order, which
// is unspecified. Keys of any other type are left out. The result is never
// nil.
func getKeysInterface(m map[interface{}]interface{}) []string {
	names := make([]string, 0, len(m))
	for key := range m {
		name, isString := key.(string)
		if !isString {
			continue
		}
		names = append(names, name)
	}
	return names
}
// convertInterfaceMapToStringMap returns a string-keyed copy of m.
//
// Each key is rendered with the %v verb, whatever its type, and each value is
// passed through convertToJSONCompatible. The nullable handling that
// convertToJSONCompatible applies to maps is not applied to m itself, only to
// maps nested in its values.
func convertInterfaceMapToStringMap(m map[interface{}]interface{}) map[string]interface{} {
	converted := make(map[string]interface{}, len(m))
	for key, value := range m {
		converted[fmt.Sprintf("%v", key)] = convertToJSONCompatible(value)
	}
	return converted
}
// convertToJSONCompatible recursively rewrites a decoded document so that it
// can be marshalled as JSON and used as a JSON schema.
//
//   - map[interface{}]interface{} becomes map[string]interface{}; keys are
//     rendered with the %v verb and values are converted recursively.
//   - []interface{} becomes a new slice with every element converted.
//   - Any other value is returned unchanged.
//
// A map entry "nullable" whose value is the boolean true is dropped and
// translated instead: a string "$ref" is replaced by a "oneOf" of that
// reference and a null-only enum; otherwise a string "type" T becomes the
// list [T, "null"]. A "nullable" entry with any other value is kept as an
// ordinary entry.
func convertToJSONCompatible(data interface{}) interface{} {
	if list, isList := data.([]interface{}); isList {
		items := make([]interface{}, 0, len(list))
		for _, item := range list {
			items = append(items, convertToJSONCompatible(item))
		}
		return items
	}

	source, isMap := data.(map[interface{}]interface{})
	if !isMap {
		return data
	}

	converted := make(map[string]interface{})
	allowsNull := false
	for key, value := range source {
		name := fmt.Sprintf("%v", key) // any key type is rendered as a string
		if flag, isBool := value.(bool); name == "nullable" && isBool && flag {
			// Dropped here; expressed through oneOf / type below.
			allowsNull = true
			continue
		}
		converted[name] = convertToJSONCompatible(value)
	}

	if !allowsNull {
		return converted
	}

	// A reference cannot carry a type list, so it is wrapped in a union of
	// the referenced schema and null.
	if ref, hasRef := converted["$ref"].(string); hasRef {
		converted["oneOf"] = []interface{}{
			map[string]interface{}{"$ref": ref},
			map[string]interface{}{"enum": []interface{}{nil}},
		}
		delete(converted, "$ref")
		return converted
	}

	// A plain type is widened to also accept null.
	if typeName, hasType := converted["type"].(string); hasType {
		converted["type"] = []interface{}{typeName, "null"}
	}
	return converted
}
// ValidateData checks data against the loaded schema called schemaName.
//
// data is marshalled to JSON and validated with the compiled schema. The
// result is nil when the document is valid. Otherwise an error is returned:
// the schema is unknown, data cannot be marshalled, the validator itself
// failed, or the document violates the schema. In the last case the message
// lists every violation as "field: description", followed by the offending
// value when one is available, joined with "; ".
func (sl *SchemaLoader) ValidateData(data interface{}, schemaName string) error {
	compiled, known := sl.schemas[schemaName]
	if !known {
		return contextutils.ErrorWithContextf("schema %s not found", schemaName)
	}

	// The validator works on JSON, so serialise first.
	payload, marshalErr := json.Marshal(data)
	if marshalErr != nil {
		return contextutils.WrapError(marshalErr, "failed to marshal data")
	}

	outcome, validateErr := compiled.Validate(gojsonschema.NewBytesLoader(payload))
	if validateErr != nil {
		return contextutils.WrapError(validateErr, "validation error")
	}
	if outcome.Valid() {
		return nil
	}

	// Collect one human-readable line per violation.
	var problems []string
	for _, violation := range outcome.Errors() {
		line := fmt.Sprintf("%s: %s", violation.Field(), violation.Description())
		// Include the actual value that failed validation if available.
		if violation.Value() != nil {
			line += fmt.Sprintf(" (received: %v)", violation.Value())
		}
		problems = append(problems, line)
	}
	return contextutils.ErrorWithContextf("schema validation failed: %s", strings.Join(problems, "; "))
}
// AutoLoadSchemas builds a SchemaLoader and tries to populate it from the
// swagger file named by the SWAGGER_FILE_PATH environment variable.
//
// Loading is best-effort: when the variable is unset, the file cannot be
// stat'ed, or parsing fails, a warning is printed to stdout and the (possibly
// empty) loader is still returned. The function never returns nil as long as
// NewSchemaLoader does not.
func AutoLoadSchemas() *SchemaLoader {
	loader := NewSchemaLoader()

	swaggerPath := os.Getenv("SWAGGER_FILE_PATH")
	if swaggerPath == "" {
		fmt.Printf("Warning: SWAGGER_FILE_PATH environment variable not set\n")
		return loader
	}

	// Any stat failure (not only "does not exist") is reported the same way.
	if _, err := os.Stat(swaggerPath); err != nil {
		fmt.Printf("Warning: swagger file not found at %s: %v\n", swaggerPath, err)
		return loader
	}

	if err := loader.LoadSchemasFromSwagger(swaggerPath); err != nil {
		fmt.Printf("Warning: failed to load schemas from %s: %v\n", swaggerPath, err)
		return loader
	}

	fmt.Printf("Successfully loaded schemas from %s\n", swaggerPath)
	return loader
}
// IsEndpointDocumented reports whether the cached swagger document declares
// the given HTTP method for the given request path.
//
// The literal path is looked up first; if that entry does not declare the
// method, every swagger path template that matches the request path
// (see pathMatchesPattern) is consulted as well. The method is compared in
// lower case. It returns false when no swagger data is cached, when the
// "paths" section is missing or malformed, or when the literal path entry
// exists but is not a map.
func (sl *SchemaLoader) IsEndpointDocumented(path, method string) bool {
	if sl.swaggerData == nil {
		return false
	}

	// asStringMap accepts both map shapes a decoded swagger document may
	// contain and normalizes them to string keys.
	asStringMap := func(node interface{}) (map[string]interface{}, bool) {
		switch typed := node.(type) {
		case map[string]interface{}:
			return typed, true
		case map[interface{}]interface{}:
			return convertInterfaceMapToStringMap(typed), true
		default:
			return nil, false
		}
	}

	spec := sl.swaggerData
	paths, ok := asStringMap(spec["paths"])
	if !ok {
		return false
	}

	verb := strings.ToLower(method)

	// Literal path entry.
	if exact, found := paths[path]; found {
		operations, ok := asStringMap(exact)
		if !ok {
			return false
		}
		if _, documented := operations[verb]; documented {
			return true
		}
	}

	// Templated entries such as /users/{id}.
	for template, item := range paths {
		if !sl.pathMatchesPattern(path, template) {
			continue
		}
		operations, ok := asStringMap(item)
		if !ok {
			continue
		}
		if _, documented := operations[verb]; documented {
			return true
		}
	}

	return false
}
// pathMatchesPattern reports whether requestPath fits the swagger path
// template swaggerPath.
//
// Both paths are split on "/" and must have the same number of segments.
// A template segment wrapped in braces (for example "{id}") accepts any
// request segment, including an empty one; every other segment must be
// identical in both paths.
func (sl *SchemaLoader) pathMatchesPattern(requestPath, swaggerPath string) bool {
	actual := strings.Split(requestPath, "/")
	template := strings.Split(swaggerPath, "/")

	if len(actual) != len(template) {
		return false
	}

	for idx := range template {
		want := template[idx]
		isParam := strings.HasPrefix(want, "{") && strings.HasSuffix(want, "}")
		if !isParam && want != actual[idx] {
			return false
		}
	}
	return true
}
// DetermineRequestSchemaFromPath returns the name of the schema referenced by
// the application/json request body of the given path and method in the
// cached swagger document.
//
// The literal path is preferred; when it is absent the first swagger path
// template that matches (see pathMatchesPattern) is used. Map iteration
// order decides which template wins if several match. The lookup then walks
// <method> -> requestBody -> content -> application/json -> schema -> $ref
// and returns the last "/"-separated element of the reference. An empty
// string is returned whenever any step of that walk is missing or has an
// unexpected type, or when the reference has fewer than four "/"-separated
// parts.
func (sl *SchemaLoader) DetermineRequestSchemaFromPath(path, method string) string {
	if sl.swaggerData == nil {
		return ""
	}

	// asStringMap accepts both map shapes a decoded swagger document may
	// contain and normalizes them to string keys.
	asStringMap := func(node interface{}) (map[string]interface{}, bool) {
		switch typed := node.(type) {
		case map[string]interface{}:
			return typed, true
		case map[interface{}]interface{}:
			return convertInterfaceMapToStringMap(typed), true
		default:
			return nil, false
		}
	}

	spec := sl.swaggerData
	paths, ok := asStringMap(spec["paths"])
	if !ok {
		return ""
	}

	// Resolve the path item: literal entry first, then the first template hit.
	node, found := paths[path]
	if !found {
		for template := range paths {
			if sl.pathMatchesPattern(path, template) {
				node = paths[template]
				break
			}
		}
		if node == nil {
			return ""
		}
	}

	// Descend one key at a time; every level must be a map holding the key.
	trail := []string{strings.ToLower(method), "requestBody", "content", "application/json", "schema"}
	for _, key := range trail {
		level, ok := asStringMap(node)
		if !ok {
			return ""
		}
		if node, ok = level[key]; !ok {
			return ""
		}
	}

	schemaMap, ok := asStringMap(node)
	if !ok {
		return ""
	}
	refStr, ok := schemaMap["$ref"].(string)
	if !ok {
		return ""
	}

	// $ref format: "#/components/schemas/SchemaName"
	parts := strings.Split(refStr, "/")
	if len(parts) < 4 {
		return ""
	}
	return parts[len(parts)-1]
}
// DetermineSchemaFromPath returns the name of the schema describing the
// success response of the given path and method in the cached swagger
// document.
//
// The literal path is preferred; when it is absent the first swagger path
// template that matches (see pathMatchesPattern) is used. The first of the
// status codes 200, 201 and 202 that is declared is taken as the success
// response, and its application/json schema is inspected:
//
//   - a direct "$ref" into "#/components/schemas/" yields that schema name;
//   - an array whose items carry such a "$ref" yields "<Item>Array". That
//     synthetic array schema is compiled on first use and stored in
//     sl.schemas; if compiling it fails, a warning is printed and the item
//     schema name is returned instead.
//
// An empty string is returned whenever any step of the lookup is missing or
// has an unexpected type.
//
// NOTE(review): the first call for a given array response writes to
// sl.schemas with no locking visible here. If this method can run on several
// goroutines at once (it is used by the response validation middleware),
// that write needs a guard on the loader — confirm.
func (sl *SchemaLoader) DetermineSchemaFromPath(path, method string) string {
	if sl.swaggerData == nil {
		return ""
	}

	const refPrefix = "#/components/schemas/"

	// asStringMap accepts both map shapes a decoded swagger document may
	// contain and normalizes them to string keys.
	asStringMap := func(node interface{}) (map[string]interface{}, bool) {
		switch typed := node.(type) {
		case map[string]interface{}:
			return typed, true
		case map[interface{}]interface{}:
			return convertInterfaceMapToStringMap(typed), true
		default:
			return nil, false
		}
	}

	// descend follows keys through nested maps; every level must be a map
	// that contains the requested key.
	descend := func(node interface{}, keys ...string) (interface{}, bool) {
		for _, key := range keys {
			level, ok := asStringMap(node)
			if !ok {
				return nil, false
			}
			if node, ok = level[key]; !ok {
				return nil, false
			}
		}
		return node, true
	}

	paths, ok := asStringMap(sl.swaggerData["paths"])
	if !ok {
		return ""
	}

	// Resolve the path item: literal entry first, then the first template hit.
	pathItem, found := paths[path]
	if !found {
		for template := range paths {
			if sl.pathMatchesPattern(path, template) {
				pathItem = paths[template]
				break
			}
		}
		if pathItem == nil {
			return ""
		}
	}

	responsesNode, ok := descend(pathItem, strings.ToLower(method), "responses")
	if !ok {
		return ""
	}
	responses, ok := asStringMap(responsesNode)
	if !ok {
		return ""
	}

	// Pick the first declared success status, in order of preference.
	var success interface{}
	for _, code := range []string{"200", "201", "202"} {
		if resp, present := responses[code]; present {
			success = resp
			break
		}
	}
	if success == nil {
		return ""
	}

	schemaNode, ok := descend(success, "content", "application/json", "schema")
	if !ok {
		return ""
	}
	schemaMap, ok := asStringMap(schemaNode)
	if !ok {
		return ""
	}

	// Direct reference, e.g. "#/components/schemas/DashboardResponse".
	if refStr, isString := schemaMap["$ref"].(string); isString && strings.HasPrefix(refStr, refPrefix) {
		return strings.TrimPrefix(refStr, refPrefix)
	}

	// Otherwise only "type: array" with "items.$ref" is understood.
	if typeStr, isString := schemaMap["type"].(string); !isString || typeStr != "array" {
		return ""
	}
	itemsMap, ok := asStringMap(schemaMap["items"])
	if !ok {
		return ""
	}
	itemRef, isString := itemsMap["$ref"].(string)
	if !isString || !strings.HasPrefix(itemRef, refPrefix) {
		return ""
	}

	itemSchemaName := strings.TrimPrefix(itemRef, refPrefix)
	arraySchemaName := fmt.Sprintf("%sArray", itemSchemaName)
	if _, cached := sl.schemas[arraySchemaName]; cached {
		return arraySchemaName
	}

	// Build a standalone array schema; the component schemas are embedded so
	// that the "$ref" inside "items" can be resolved.
	arraySchema := map[string]interface{}{
		"$schema": "http://json-schema.org/draft-07/schema#",
		"components": map[string]interface{}{
			"schemas": sl.jsonCompatibleSchemas,
		},
		"type": "array",
		"items": map[string]interface{}{
			"$ref": refPrefix + itemSchemaName,
		},
	}

	schemaBytes, err := json.Marshal(arraySchema)
	if err != nil {
		fmt.Printf("Warning: failed to marshal array schema %s: %v\n", arraySchemaName, err)
		return itemSchemaName // Fallback to item schema
	}

	compiled, err := gojsonschema.NewSchema(gojsonschema.NewBytesLoader(schemaBytes))
	if err != nil {
		fmt.Printf("Warning: failed to load array schema %s: %v\n", arraySchemaName, err)
		return itemSchemaName // Fallback to item schema
	}

	sl.schemas[arraySchemaName] = compiled
	fmt.Printf("Created array schema: %s\n", arraySchemaName)
	return arraySchemaName
}
// DetermineResponseSchemaFromPath returns the name of the schema referenced
// by the application/json response declared for exactly statusCode on the
// given path and method in the cached swagger document.
//
// The literal path is preferred; when it is absent the first swagger path
// template that matches (see pathMatchesPattern) is used. The lookup walks
// <method> -> responses -> <statusCode> -> content -> application/json ->
// schema and returns the part of "$ref" that follows
// "#/components/schemas/". An empty string is returned whenever any step is
// missing or has an unexpected type, or the reference does not carry that
// prefix. Inline and array schemas are not resolved here.
func (sl *SchemaLoader) DetermineResponseSchemaFromPath(path, method, statusCode string) string {
	if sl.swaggerData == nil {
		return ""
	}

	const refPrefix = "#/components/schemas/"

	// asStringMap accepts both map shapes a decoded swagger document may
	// contain and normalizes them to string keys.
	asStringMap := func(node interface{}) (map[string]interface{}, bool) {
		switch typed := node.(type) {
		case map[string]interface{}:
			return typed, true
		case map[interface{}]interface{}:
			return convertInterfaceMapToStringMap(typed), true
		default:
			return nil, false
		}
	}

	spec := sl.swaggerData
	paths, ok := asStringMap(spec["paths"])
	if !ok {
		return ""
	}

	// Resolve the path item: literal entry first, then the first template hit.
	node, found := paths[path]
	if !found {
		for template := range paths {
			if sl.pathMatchesPattern(path, template) {
				node = paths[template]
				break
			}
		}
		if node == nil {
			return ""
		}
	}

	// Descend one key at a time; every level must be a map holding the key.
	trail := []string{strings.ToLower(method), "responses", statusCode, "content", "application/json", "schema"}
	for _, key := range trail {
		level, ok := asStringMap(node)
		if !ok {
			return ""
		}
		if node, ok = level[key]; !ok {
			return ""
		}
	}

	schemaMap, ok := asStringMap(node)
	if !ok {
		return ""
	}

	if refStr, isString := schemaMap["$ref"].(string); isString && strings.HasPrefix(refStr, refPrefix) {
		return strings.TrimPrefix(refStr, refPrefix)
	}
	return ""
}
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"quizapp/internal/observability"
"github.com/gin-gonic/gin"
)
// globalSchemaLoader is the package-wide SchemaLoader shared by the
// validation middlewares. It stays nil until initSchemaLoader is first called.
var globalSchemaLoader *SchemaLoader

// initSchemaLoader returns the shared SchemaLoader, creating it through
// AutoLoadSchemas on the first call and reusing it afterwards.
// The check-then-assign is not synchronized.
func initSchemaLoader() *SchemaLoader {
	if globalSchemaLoader != nil {
		return globalSchemaLoader
	}
	globalSchemaLoader = AutoLoadSchemas()
	return globalSchemaLoader
}
// ResponseValidationMiddleware returns a gin middleware that validates 2xx
// JSON responses against the swagger schema declared for the endpoint.
//
// The handler chain runs against a capturing writer, so everything written
// through Write is buffered. After the chain finishes:
//
//   - non-2xx responses are replayed unchanged;
//   - responses whose Content-Type is text/event-stream (with or without
//     parameters such as a charset) are replayed without validation;
//   - bodies that are not valid JSON are logged and replayed unchanged;
//   - JSON bodies are validated against the schema for the exact status code,
//     falling back to the generic success schema; when no schema is known the
//     body is replayed with a warning;
//   - when validation fails the original body is discarded and a 400 JSON
//     error describing the mismatch is sent instead.
//
// The schema loader is resolved once, when the middleware is constructed.
func ResponseValidationMiddleware(logger *observability.Logger) gin.HandlerFunc {
	schemaLoader := initSchemaLoader()

	return func(c *gin.Context) {
		ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "response_validation")
		defer span.End()

		// Swap in a writer that buffers the body so it can be inspected
		// before anything reaches the client.
		originalWriter := c.Writer
		capture := &responseCaptureWriter{
			ResponseWriter: originalWriter,
			body:           &bytes.Buffer{},
			status:         0,
		}
		c.Writer = capture

		c.Next()

		statusCode := capture.status
		if statusCode == 0 {
			statusCode = c.Writer.Status()
		}

		// replay restores the real writer and sends the buffered response.
		replay := func() {
			c.Writer = originalWriter
			c.Writer.WriteHeader(statusCode)
			_, _ = c.Writer.Write(capture.body.Bytes())
		}

		// Only 2xx responses are validated.
		if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
			span.SetAttributes(
				observability.AttributeTypeFilter("non_200_status"),
			)
			replay()
			return
		}

		// Streaming responses are never JSON documents. Match on the media
		// type prefix so a value such as "text/event-stream; charset=utf-8"
		// is recognized too.
		mediaType := strings.ToLower(strings.TrimSpace(c.Writer.Header().Get("Content-Type")))
		if strings.HasPrefix(mediaType, "text/event-stream") {
			span.SetAttributes(
				observability.AttributeTypeFilter("streaming_response"),
			)
			logger.Debug(ctx, "Skipping validation for streaming response", map[string]interface{}{
				"method": c.Request.Method,
				"path":   c.Request.URL.Path,
			})
			replay()
			return
		}

		var responseData interface{}
		if err := json.Unmarshal(capture.body.Bytes(), &responseData); err != nil {
			span.SetAttributes(
				observability.AttributeTypeFilter("json_parse_failed"),
			)
			logger.Error(ctx, "Failed to parse JSON response", err, map[string]interface{}{
				"method": c.Request.Method,
				"path":   c.Request.URL.Path,
			})
			replay()
			return
		}

		// Prefer the schema declared for the actual status code and fall
		// back to the generic success schema.
		schemaName := schemaLoader.DetermineResponseSchemaFromPath(c.Request.URL.Path, c.Request.Method, fmt.Sprintf("%d", statusCode))
		if schemaName == "" {
			schemaName = schemaLoader.DetermineSchemaFromPath(c.Request.URL.Path, c.Request.Method)
		}

		span.SetAttributes(
			observability.AttributeSearch(c.Request.URL.Path),
			observability.AttributeTypeFilter(c.Request.Method),
		)

		if schemaName == "" {
			span.SetAttributes(
				observability.AttributeTypeFilter("no_schema_found"),
			)
			logger.Warn(ctx, "No schema found for endpoint", map[string]interface{}{
				"method": c.Request.Method,
				"path":   c.Request.URL.Path,
			})
			replay()
			return
		}

		span.SetAttributes(observability.AttributeSearch(schemaName))

		if err := schemaLoader.ValidateData(responseData, schemaName); err != nil {
			span.SetAttributes(
				observability.AttributeTypeFilter("validation_failed"),
			)
			// Only the first 200 bytes of the body are logged.
			logger.Error(ctx, "Response validation failed", err, map[string]interface{}{
				"method":        c.Request.Method,
				"path":          c.Request.URL.Path,
				"schema_name":   schemaName,
				"error":         err.Error(),
				"response_data": capture.body.String()[:int(math.Min(200, float64(capture.body.Len())))],
			})

			// Replace the handler's response with a 400 error document.
			c.Writer = originalWriter
			c.Writer.WriteHeader(http.StatusBadRequest)
			_ = json.NewEncoder(c.Writer).Encode(gin.H{
				"error":   "Response validation failed",
				"message": "API response does not match the specification",
				"method":  c.Request.Method,
				"path":    c.Request.URL.Path,
				"schema":  schemaName,
				"details": err.Error(),
			})
			return
		}

		span.SetAttributes(
			observability.AttributeTypeFilter("validation_passed"),
		)
		replay()
	}
}
// responseCaptureWriter wraps a gin.ResponseWriter so that the response body
// can be inspected before it is sent. Bytes passed to Write are collected in
// body instead of being forwarded, and the most recent status code passed to
// WriteHeader is remembered in status.
//
// NOTE(review): only Write is intercepted. Other write paths of the embedded
// gin.ResponseWriter (for example WriteString) are not overridden here and
// therefore go to the wrapped writer — confirm this is intended.
type responseCaptureWriter struct {
	gin.ResponseWriter
	body   *bytes.Buffer
	status int
}

// WriteHeader records statusCode and forwards it to the wrapped writer.
func (rcw *responseCaptureWriter) WriteHeader(statusCode int) {
	rcw.status = statusCode
	rcw.ResponseWriter.WriteHeader(statusCode)
}

// Write appends b to the capture buffer; nothing is sent to the wrapped writer.
func (rcw *responseCaptureWriter) Write(b []byte) (int, error) {
	written, err := rcw.body.Write(b)
	return written, err
}

// Status returns the recorded status code, or the wrapped writer's status
// when WriteHeader has not recorded a non-zero code.
func (rcw *responseCaptureWriter) Status() int {
	if rcw.status == 0 {
		return rcw.ResponseWriter.Status()
	}
	return rcw.status
}
// isStaticFile reports whether path is exempt from the swagger documentation
// check: one of the fixed paths "/swagger.yaml", "/swaggerz", "/configz" or
// "/", or anything below "/backend/".
func isStaticFile(path string) bool {
	switch path {
	case "/swagger.yaml", "/swaggerz", "/configz", "/":
		return true
	}
	// Everything under /backend/ is treated as a static asset.
	return strings.HasPrefix(path, "/backend/")
}
// RequestValidationMiddleware returns a gin middleware that rejects calls to
// endpoints missing from the swagger document and validates JSON request
// bodies against the schema declared for the endpoint.
//
// Behavior per request:
//
//   - paths accepted by isStaticFile pass straight through;
//   - an undocumented path/method pair is answered with 404 and the chain is
//     aborted;
//   - for POST, PUT and PATCH the body is read once and put back so handlers
//     can still consume it. When a request schema is known and the body is
//     non-empty, valid JSON, it is validated; a mismatch is answered with 400
//     and the chain is aborted. Empty bodies and bodies that are not valid
//     JSON are passed on to the handler unvalidated.
//
// Request payloads may carry credentials, so only the body size is logged,
// never its content. The schema loader is resolved once, when the middleware
// is constructed.
func RequestValidationMiddleware(logger *observability.Logger) gin.HandlerFunc {
	schemaLoader := initSchemaLoader()

	return func(c *gin.Context) {
		ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "request_validation")
		defer span.End()

		path := c.Request.URL.Path
		method := c.Request.Method

		logger.Info(ctx, "Request validation middleware called", map[string]interface{}{
			"method": method,
			"path":   path,
		})
		span.SetAttributes(
			observability.AttributeSearch(path),
			observability.AttributeTypeFilter(method),
		)

		if isStaticFile(path) {
			c.Next()
			return
		}

		if !schemaLoader.IsEndpointDocumented(path, method) {
			logger.Warn(ctx, "Undocumented API call attempted", map[string]interface{}{
				"method":     method,
				"path":       path,
				"ip":         c.ClientIP(),
				"user_agent": c.Request.UserAgent(),
			})
			c.JSON(http.StatusNotFound, gin.H{
				"error":   "Endpoint not found",
				"message": "The requested endpoint is not documented in the API specification",
			})
			c.Abort()
			return
		}

		span.SetAttributes(
			observability.AttributeTypeFilter("endpoint_documented"),
		)

		switch method {
		case http.MethodPost, http.MethodPut, http.MethodPatch:
			schemaName := schemaLoader.DetermineRequestSchemaFromPath(path, method)
			logger.Info(ctx, "Request validation schema determined", map[string]interface{}{
				"method":      method,
				"path":        path,
				"schema_name": schemaName,
			})
			if schemaName == "" {
				logger.Warn(ctx, "No schema found for endpoint", map[string]interface{}{
					"method": method,
					"path":   path,
				})
			}

			// Read the body exactly once and hand handlers a fresh reader
			// over the same bytes.
			body, readErr := c.GetRawData()
			if readErr != nil {
				logger.Warn(ctx, "Failed to read request body for validation", map[string]interface{}{
					"method": method,
					"path":   path,
					"error":  readErr.Error(),
				})
			}
			hasBody := readErr == nil && len(body) > 0
			if hasBody {
				c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
			}

			if schemaName != "" && hasBody {
				logger.Info(ctx, "Request body received", map[string]interface{}{
					"method":      method,
					"path":        path,
					"schema_name": schemaName,
					"body_bytes":  len(body),
				})

				var requestData interface{}
				// Bodies that are not valid JSON are left for the handler to reject.
				if err := json.Unmarshal(body, &requestData); err == nil {
					if err := schemaLoader.ValidateData(requestData, schemaName); err != nil {
						logger.Error(ctx, "Request validation failed", err, map[string]interface{}{
							"method":      method,
							"path":        path,
							"schema_name": schemaName,
							"error":       err.Error(),
							"body_bytes":  len(body),
						})
						span.SetAttributes(
							observability.AttributeTypeFilter("validation_failed"),
							observability.AttributeSearch(path),
							observability.AttributeTypeFilter(method),
							observability.AttributeTypeFilter(schemaName),
							observability.AttributeTypeFilter("validation_error:"+err.Error()),
						)
						c.JSON(http.StatusBadRequest, gin.H{
							"error":   "Invalid request data",
							"message": "Request data does not match the API specification",
							"method":  method,
							"path":    path,
							"schema":  schemaName,
							"details": err.Error(),
						})
						c.Abort()
						return
					}
				}
			}
		}

		c.Next()
	}
}
package models
import (
"database/sql"
"time"
)
// AuthAPIKey represents an API key for programmatic authentication
// This is separate from user_api_keys which stores AI provider API keys
//
// KeyHash is tagged json:"-" and is therefore never part of the JSON output.
// PermissionLevel is expected to hold one of the PermissionLevel* constants
// (see IsValidPermissionLevel).
//
// NOTE(review): LastUsedAt is a sql.NullTime with a plain json tag. Unless a
// custom marshaller for this type exists elsewhere, encoding/json renders it
// as an object with Time and Valid members rather than a timestamp or null —
// confirm that is the intended wire format.
type AuthAPIKey struct {
ID int `json:"id"`
UserID int `json:"user_id"`
KeyName string `json:"key_name"`
KeyHash string `json:"-"` // Never expose the hash
KeyPrefix string `json:"key_prefix"`
PermissionLevel string `json:"permission_level"` // "readonly" or "full"
LastUsedAt sql.NullTime `json:"last_used_at"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Permission levels an AuthAPIKey can carry.
const (
	// PermissionLevelReadonly marks a key as read-only.
	PermissionLevelReadonly = "readonly"
	// PermissionLevelFull marks a key as unrestricted.
	PermissionLevelFull = "full"
)

// IsValidPermissionLevel reports whether level is one of the known
// permission levels. The comparison is exact and case-sensitive.
func IsValidPermissionLevel(level string) bool {
	switch level {
	case PermissionLevelReadonly, PermissionLevelFull:
		return true
	default:
		return false
	}
}
// CanPerformMethod reports whether a key with this permission level may issue
// a request with the given HTTP method. Keys with the full permission level
// may use any method; every other level is limited to "GET" and "HEAD"
// (compared exactly, upper case).
func (k *AuthAPIKey) CanPerformMethod(method string) bool {
	switch {
	case k.PermissionLevel == PermissionLevelFull:
		return true
	case method == "GET", method == "HEAD":
		// Read-only access.
		return true
	default:
		return false
	}
}
// Package models defines data structures used throughout the quiz application.
package models
import (
"database/sql"
"encoding/json"
"time"
"quizapp/internal/api"
)
// User represents a user in the system
//
// Optional columns are held in sql.Null* wrappers. JSON output is produced by
// the custom User.MarshalJSON method, which renders those wrappers as plain
// values or null; the json tags on this struct therefore only matter for
// decoding. The yaml tags describe the YAML field names.
//
// PasswordHash is excluded from both JSON and YAML. AIAPIKey is excluded from
// JSON only.
// NOTE(review): AIAPIKey keeps a yaml name ("ai_api_key"), so it is part of
// any YAML encoding of this struct — confirm that is intended.
type User struct {
ID int `json:"id" yaml:"id"`
Username string `json:"username" yaml:"username"`
Email sql.NullString `json:"email" yaml:"email"`
Timezone sql.NullString `json:"timezone" yaml:"timezone"`
PasswordHash sql.NullString `json:"-" yaml:"-"` // Omit from JSON responses
LastActive sql.NullTime `json:"last_active" yaml:"last_active"`
PreferredLanguage sql.NullString `json:"preferred_language" yaml:"preferred_language"`
CurrentLevel sql.NullString `json:"current_level" yaml:"current_level"`
AIProvider sql.NullString `json:"ai_provider" yaml:"ai_provider"`
AIModel sql.NullString `json:"ai_model" yaml:"ai_model"`
AIEnabled sql.NullBool `json:"ai_enabled" yaml:"ai_enabled"`
AIAPIKey sql.NullString `json:"-" yaml:"ai_api_key"` // Omit from JSON responses
WordOfDayEmailEnabled sql.NullBool `json:"word_of_day_email_enabled" yaml:"word_of_day_email_enabled"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
UpdatedAt time.Time `json:"updated_at" yaml:"updated_at"`
// Roles is left out of the JSON output when empty.
Roles []Role `json:"roles,omitempty" yaml:"roles,omitempty"`
}
// Role represents a role in the system
//
// A role has a numeric ID, a name, a free-text description and creation and
// update timestamps. All fields are serialized to JSON and YAML under
// snake_case names.
type Role struct {
ID int `json:"id" yaml:"id"`
Name string `json:"name" yaml:"name"`
Description string `json:"description" yaml:"description"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
UpdatedAt time.Time `json:"updated_at" yaml:"updated_at"`
}
// UserRole represents the mapping between users and roles
//
// Each value links one user (UserID) to one role (RoleID) and records when
// the link was created.
type UserRole struct {
ID int `json:"id" yaml:"id"`
UserID int `json:"user_id" yaml:"user_id"`
RoleID int `json:"role_id" yaml:"role_id"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
}
// Snippet represents a vocabulary snippet saved by a user
//
// It stores the original text together with its translation and the source
// and target languages. The pointer-typed fields (QuestionID, SectionID,
// StoryID, Context, DifficultyLevel) may be nil, in which case they are
// encoded as JSON null.
type Snippet struct {
ID int64 `json:"id" yaml:"id"`
UserID int64 `json:"user_id" yaml:"user_id"`
OriginalText string `json:"original_text" yaml:"original_text"`
TranslatedText string `json:"translated_text" yaml:"translated_text"`
SourceLanguage string `json:"source_language" yaml:"source_language"`
TargetLanguage string `json:"target_language" yaml:"target_language"`
// Optional references; presumably the question, section and story the
// snippet was taken from — verify against callers.
QuestionID *int64 `json:"question_id" yaml:"question_id"`
SectionID *int64 `json:"section_id" yaml:"section_id"`
StoryID *int64 `json:"story_id" yaml:"story_id"`
Context *string `json:"context" yaml:"context"`
DifficultyLevel *string `json:"difficulty_level" yaml:"difficulty_level"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
UpdatedAt time.Time `json:"updated_at" yaml:"updated_at"`
}
// TranslationCache represents a cached translation result
//
// An entry keeps the original text, its translation, the language pair, and
// the creation and expiry timestamps.
type TranslationCache struct {
ID int `json:"id" yaml:"id"`
// TextHash is presumably a hash of the original text used as the lookup
// key; the hashing scheme is not visible here — verify against callers.
TextHash string `json:"text_hash" yaml:"text_hash"`
OriginalText string `json:"original_text" yaml:"original_text"`
SourceLanguage string `json:"source_language" yaml:"source_language"`
TargetLanguage string `json:"target_language" yaml:"target_language"`
TranslatedText string `json:"translated_text" yaml:"translated_text"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
ExpiresAt time.Time `json:"expires_at" yaml:"expires_at"`
}
// MarshalJSON implements json.Marshaler for User.
//
// The sql.Null* fields are flattened: a valid value is emitted as a plain
// string, bool or timestamp, an invalid one as JSON null. PasswordHash and
// AIAPIKey are never emitted, and roles is omitted when the slice is empty.
// WordOfDayEmailEnabled is not part of the output either, although the User
// struct carries a json tag for it.
func (u User) MarshalJSON() (result0 []byte, err error) {
	// userJSON fixes the field order and names of the encoded document.
	type userJSON struct {
		ID                int        `json:"id"`
		Username          string     `json:"username"`
		Email             *string    `json:"email"`
		Timezone          *string    `json:"timezone"`
		LastActive        *time.Time `json:"last_active"`
		PreferredLanguage *string    `json:"preferred_language"`
		CurrentLevel      *string    `json:"current_level"`
		AIProvider        *string    `json:"ai_provider"`
		AIModel           *string    `json:"ai_model"`
		AIEnabled         *bool      `json:"ai_enabled"`
		CreatedAt         time.Time  `json:"created_at"`
		UpdatedAt         time.Time  `json:"updated_at"`
		Roles             []Role     `json:"roles,omitempty"`
	}

	view := userJSON{
		ID:        u.ID,
		Username:  u.Username,
		CreatedAt: u.CreatedAt,
		UpdatedAt: u.UpdatedAt,
		Roles:     u.Roles,
	}

	// Nullable columns become pointers so that invalid values encode as null.
	view.Email = nullStringToPointer(u.Email)
	view.Timezone = nullStringToPointer(u.Timezone)
	view.LastActive = nullTimeToPointer(u.LastActive)
	view.PreferredLanguage = nullStringToPointer(u.PreferredLanguage)
	view.CurrentLevel = nullStringToPointer(u.CurrentLevel)
	view.AIProvider = nullStringToPointer(u.AIProvider)
	view.AIModel = nullStringToPointer(u.AIModel)
	view.AIEnabled = nullBoolToPointer(u.AIEnabled)

	return json.Marshal(view)
}
// Helper functions for converting sql.Null types to pointers

// nullStringToPointer converts ns to a *string: nil when ns is not valid
// (SQL NULL), otherwise a pointer to a copy of the contained string.
func nullStringToPointer(ns sql.NullString) *string {
	if !ns.Valid {
		return nil
	}
	return &ns.String
}
// nullTimeToPointer converts nt to a *time.Time: nil when nt is not valid
// (SQL NULL), otherwise a pointer to a copy of the contained time.
func nullTimeToPointer(nt sql.NullTime) *time.Time {
	if !nt.Valid {
		return nil
	}
	return &nt.Time
}
// nullBoolToPointer converts nb to a *bool: nil when nb is not valid
// (SQL NULL), otherwise a pointer to a copy of the contained bool.
func nullBoolToPointer(nb sql.NullBool) *bool {
	if !nb.Valid {
		return nil
	}
	return &nb.Bool
}
// nullInt32ToPointer converts ni to a *int32: nil when ni is not valid
// (SQL NULL), otherwise a pointer to a copy of the contained value.
func nullInt32ToPointer(ni sql.NullInt32) *int32 {
	if !ni.Valid {
		return nil
	}
	return &ni.Int32
}
// UserAPIKey represents an API key for a specific provider for a user.
// The key itself is tagged `json:"-"`, so encoding/json never emits it; the
// remaining fields are serialized under snake_case names.
type UserAPIKey struct {
ID int `json:"id"`
UserID int `json:"user_id"`
Provider string `json:"provider"`
APIKey string `json:"-"` // Omit from JSON responses for security
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Question represents a quiz question.
//
// Content is a free-form map; GetCorrectAnswerText reads its "options" entry
// and uses CorrectAnswer as an index into that list. The optional text fields
// below carry `omitempty` for JSON only, so they are still written to YAML
// when empty.
type Question struct {
ID int `json:"id" yaml:"id"`
Type QuestionType `json:"type" yaml:"type"`
Language string `json:"language" yaml:"language"`
Level string `json:"level" yaml:"level"`
DifficultyScore float64 `json:"difficulty_score" yaml:"difficulty_score"`
// Content holds the type-specific payload of the question.
Content map[string]interface{} `json:"content" yaml:"content"`
// CorrectAnswer is the index of the correct entry in Content["options"].
CorrectAnswer int `json:"correct_answer" yaml:"correct_answer"`
Explanation string `json:"explanation,omitempty" yaml:"explanation"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
Status QuestionStatus `json:"status" yaml:"status"`
// Test data field for specifying which users should have this question
Users []string `json:"users,omitempty" yaml:"users,omitempty"`
// Variety elements for question generation diversity
TopicCategory string `json:"topic_category,omitempty" yaml:"topic_category"`
GrammarFocus string `json:"grammar_focus,omitempty" yaml:"grammar_focus"`
VocabularyDomain string `json:"vocabulary_domain,omitempty" yaml:"vocabulary_domain"`
Scenario string `json:"scenario,omitempty" yaml:"scenario"`
StyleModifier string `json:"style_modifier,omitempty" yaml:"style_modifier"`
DifficultyModifier string `json:"difficulty_modifier,omitempty" yaml:"difficulty_modifier"`
TimeContext string `json:"time_context,omitempty" yaml:"time_context"`
}
// UserQuestion represents the mapping between users and questions.
// Each value links one UserID to one QuestionID and records when the link
// was created.
type UserQuestion struct {
ID int `json:"id"`
UserID int `json:"user_id"`
QuestionID int `json:"question_id"`
CreatedAt time.Time `json:"created_at"`
}
// QuestionReport represents a report of a question by a user.
// It identifies the reported question, the reporting user, the free-text
// reason, and the time the report was created.
type QuestionReport struct {
ID int `json:"id"`
QuestionID int `json:"question_id"`
ReportedByUserID int `json:"reported_by_user_id"`
ReportReason string `json:"report_reason"`
CreatedAt time.Time `json:"created_at"`
}
// QuestionType represents the type of question.
// It is a string type; the supported values are the QuestionType constants
// declared in this package (Vocabulary, FillInBlank, QuestionAnswer,
// ReadingComprehension).
type QuestionType string
// QuestionStatus represents the status of a question.
// It is a string type whose values are the constants below.
type QuestionStatus string
// Question statuses.
const (
// QuestionStatusActive is for questions that are in active use
QuestionStatusActive QuestionStatus = "active"
// QuestionStatusReported is for questions that have been reported as incorrect
QuestionStatusReported QuestionStatus = "reported"
)
// Question types supported by the system.
// The string values are the serialized form of each QuestionType.
const (
// Vocabulary represents vocabulary in context questions
Vocabulary QuestionType = "vocabulary"
// FillInBlank represents fill-in-the-blank questions
FillInBlank QuestionType = "fill_blank"
// QuestionAnswer represents simple Q&A questions
QuestionAnswer QuestionType = "qa"
// ReadingComprehension represents reading comprehension questions
ReadingComprehension QuestionType = "reading_comprehension"
)
// UserResponse represents a user's answer to a question.
// ConfidenceLevel is nullable (sql.NullInt32); the type's MarshalJSON method
// converts it to a pointer so it is encoded as a number or null.
type UserResponse struct {
ID int `json:"id" yaml:"id"`
UserID int `json:"user_id" yaml:"user_id"`
QuestionID int `json:"question_id" yaml:"question_id"`
// UserAnswerIndex is the index of the option the user selected.
UserAnswerIndex int `json:"user_answer_index" yaml:"user_answer_index"`
IsCorrect bool `json:"is_correct" yaml:"is_correct"`
// ResponseTimeMs is the response time in milliseconds, per the field name.
ResponseTimeMs int `json:"response_time_ms" yaml:"response_time_ms"`
ConfidenceLevel sql.NullInt32 `json:"confidence_level" yaml:"confidence_level"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
}
// MarshalJSON implements json.Marshaler for UserResponse.
//
// ConfidenceLevel (sql.NullInt32) is converted to a *int32 before encoding,
// so it is emitted as a number when valid and as JSON null otherwise. All
// other fields are copied through unchanged.
func (ur UserResponse) MarshalJSON() (result0 []byte, err error) {
	// userResponseJSON is the wire representation of a UserResponse.
	type userResponseJSON struct {
		ID              int       `json:"id"`
		UserID          int       `json:"user_id"`
		QuestionID      int       `json:"question_id"`
		UserAnswerIndex int       `json:"user_answer_index"`
		IsCorrect       bool      `json:"is_correct"`
		ResponseTimeMs  int       `json:"response_time_ms"`
		ConfidenceLevel *int32    `json:"confidence_level"`
		CreatedAt       time.Time `json:"created_at"`
	}

	payload := userResponseJSON{
		ID:              ur.ID,
		UserID:          ur.UserID,
		QuestionID:      ur.QuestionID,
		UserAnswerIndex: ur.UserAnswerIndex,
		IsCorrect:       ur.IsCorrect,
		ResponseTimeMs:  ur.ResponseTimeMs,
		ConfidenceLevel: nullInt32ToPointer(ur.ConfidenceLevel),
		CreatedAt:       ur.CreatedAt,
	}
	return json.Marshal(&payload)
}
// PerformanceMetrics tracks user performance across different categories.
// One value holds the counters for a single user in one topic, language and
// level combination; AccuracyRate derives a percentage from the two attempt
// counters.
type PerformanceMetrics struct {
ID int `json:"id"`
UserID int `json:"user_id"`
Topic string `json:"topic"`
Language string `json:"language"`
Level string `json:"level"`
TotalAttempts int `json:"total_attempts"`
CorrectAttempts int `json:"correct_attempts"`
AverageResponseTimeMs float64 `json:"average_response_time_ms"`
DifficultyAdjustment float64 `json:"difficulty_adjustment"`
LastUpdated time.Time `json:"last_updated"`
}
// AccuracyRate returns CorrectAttempts as a percentage of TotalAttempts.
// It returns 0 when no attempts have been recorded, avoiding a division by
// zero.
func (pm *PerformanceMetrics) AccuracyRate() float64 {
	attempts := pm.TotalAttempts
	if attempts != 0 {
		return float64(pm.CorrectAttempts) / float64(attempts) * 100
	}
	return 0.0
}
// QuestionRequest represents a request for a new question.
// QuestionType is optional in the JSON form: it is omitted when empty.
type QuestionRequest struct {
UserID int `json:"user_id"`
Language string `json:"language"`
Level string `json:"level"`
QuestionType QuestionType `json:"question_type,omitempty"`
}
// AnswerRequest represents a user's answer submission.
// UserAnswer is carried as a string here, whereas UserResponse stores the
// selected option as an index.
type AnswerRequest struct {
QuestionID int `json:"question_id"`
UserAnswer string `json:"user_answer"`
ResponseTimeMs int `json:"response_time_ms"`
}
// AnswerResponse represents the response to an answer submission.
// It echoes the user's answer alongside the correct one and the explanation;
// NextDifficulty is omitted from JSON when empty.
type AnswerResponse struct {
IsCorrect bool `json:"is_correct"`
CorrectAnswer string `json:"correct_answer"`
UserAnswer string `json:"user_answer"`
Explanation string `json:"explanation"`
NextDifficulty string `json:"next_difficulty,omitempty"`
}
// GetCorrectAnswerText returns the text of the correct answer from the
// question content.
//
// It looks up Content["options"], expects a []interface{}, and returns the
// element at index CorrectAnswer when that element is a string. In every
// other case (missing or mistyped options, index out of range, non-string
// option) it returns the empty string.
func (q *Question) GetCorrectAnswerText() string {
	// A missing key yields a nil interface, for which the assertion fails too.
	choices, isList := q.Content["options"].([]interface{})
	if !isList {
		return ""
	}

	idx := q.CorrectAnswer
	if idx < 0 || idx >= len(choices) {
		return ""
	}

	// A non-string option leaves text at its zero value, "".
	text, _ := choices[idx].(string)
	return text
}
// UserSettings represents user preference settings.
// All fields are serialized to both JSON and YAML.
type UserSettings struct {
Language string `json:"language" yaml:"language"`
Level string `json:"level" yaml:"level"`
AIProvider string `json:"ai_provider" yaml:"ai_provider"`
AIModel string `json:"ai_model" yaml:"ai_model"`
AIEnabled bool `json:"ai_enabled" yaml:"ai_enabled"`
// NOTE(review): the JSON key is "api_key" while the YAML key is "ai_api_key",
// and unlike UserLearningPreferences this key IS included in JSON output —
// confirm both are intended.
AIAPIKey string `json:"api_key" yaml:"ai_api_key"`
}
// UserLearningPreferences represents user learning preferences and settings.
// Each field carries a `json` tag for API output and a `db` tag naming the
// corresponding column. AIAPIKey is excluded from JSON.
type UserLearningPreferences struct {
ID int `json:"id" db:"id"`
UserID int `json:"user_id" db:"user_id"`
PreferredLanguage string `json:"preferred_language" db:"preferred_language"`
CurrentLevel string `json:"current_level" db:"current_level"`
AIProvider string `json:"ai_provider" db:"ai_provider"`
AIModel string `json:"ai_model" db:"ai_model"`
AIEnabled bool `json:"ai_enabled" db:"ai_enabled"`
AIAPIKey string `json:"-" db:"ai_api_key"` // Omit from JSON for security
DailyGoal int `json:"daily_goal" db:"daily_goal"`
WeeklyGoal int `json:"weekly_goal" db:"weekly_goal"`
// Both a single preferred question type and a list of types are stored.
PreferredQuestionType string `json:"preferred_question_type" db:"preferred_question_type"`
PreferredQuestionTypes []string `json:"preferred_question_types" db:"preferred_question_types"`
PreferredDifficultyLevel string `json:"preferred_difficulty_level" db:"preferred_difficulty_level"`
PreferredTopics []string `json:"preferred_topics" db:"preferred_topics"`
PreferredQuestionCount int `json:"preferred_question_count" db:"preferred_question_count"`
SpacedRepetitionEnabled bool `json:"spaced_repetition_enabled" db:"spaced_repetition_enabled"`
AdaptiveDifficultyEnabled bool `json:"adaptive_difficulty_enabled" db:"adaptive_difficulty_enabled"`
FocusOnWeakAreas bool `json:"focus_on_weak_areas" db:"focus_on_weak_areas"`
IncludeReviewQuestions bool `json:"include_review_questions" db:"include_review_questions"`
// Tuning values for question selection; their valid ranges are not defined
// here — TODO confirm against the code that consumes them.
FreshQuestionRatio float64 `json:"fresh_question_ratio" db:"fresh_question_ratio"`
KnownQuestionPenalty float64 `json:"known_question_penalty" db:"known_question_penalty"`
ReviewIntervalDays int `json:"review_interval_days" db:"review_interval_days"`
WeakAreaBoost float64 `json:"weak_area_boost" db:"weak_area_boost"`
StudyTime string `json:"study_time" db:"study_time"`
DailyReminderEnabled bool `json:"daily_reminder_enabled" db:"daily_reminder_enabled"`
// Preferred TTS voice (e.g., it-IT-IsabellaNeural)
TTSVoice string `json:"tts_voice" db:"tts_voice"`
// LastDailyReminderSent is a pointer, so nil encodes as JSON null.
LastDailyReminderSent *time.Time `json:"last_daily_reminder_sent" db:"last_daily_reminder_sent"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// UserProgress represents a user's overall progress.
// It aggregates totals, per-topic PerformanceMetrics keyed by string, the
// list of weak areas and recent responses. SuggestedLevel is omitted from
// JSON when empty.
type UserProgress struct {
CurrentLevel string `json:"current_level"`
TotalQuestions int `json:"total_questions"`
CorrectAnswers int `json:"correct_answers"`
AccuracyRate float64 `json:"accuracy_rate"`
// PerformanceByTopic is presumably keyed by topic name — verify against the producer.
PerformanceByTopic map[string]*PerformanceMetrics `json:"performance_by_topic"`
WeakAreas []string `json:"weak_areas"`
RecentActivity []UserResponse `json:"recent_activity"`
SuggestedLevel string `json:"suggested_level,omitempty"`
}
// AIQuestionGenRequest represents a request to the AI service for question generation.
// Count is the number of questions requested. RecentQuestionHistory is
// tagged `json:"-"` and therefore never appears in the JSON form.
type AIQuestionGenRequest struct {
Language string `json:"language"`
Level string `json:"level"`
QuestionType QuestionType `json:"question_type"`
Count int `json:"count"`
RecentQuestionHistory []string `json:"-"` // Don't include in JSON, internal use
}
// AIChatRequest represents a request to the AI service for a new chat feature.
// Most fields have no JSON tag, so encoding/json would emit them under their
// Go field names; only ConversationHistory and RecentQuestionHistory carry
// explicit tags.
type AIChatRequest struct {
Language string
Level string
QuestionType QuestionType // Question type for context
Question string
Options []string
Passage string // For reading comprehension
UserAnswer string // Optional
CorrectAnswer string // Optional
IsCorrect *bool // Optional
UserMessage string
ConversationHistory []ChatMessage `json:"conversation_history,omitempty"`
RecentQuestionHistory []string `json:"-"` // Don't include in JSON, internal use
}
// ChatMessage represents a single message in the chat conversation.
// Role uses the api.ChatMessageRole type declared in another package.
type ChatMessage struct {
Role api.ChatMessageRole `json:"role"` // "user" or "assistant"
Content string `json:"content"` // The message content
}
// AIExplanationRequest represents a request for an explanation of a wrong answer.
// It carries the question text, the user's answer, the correct answer, and
// the language and level of the question.
type AIExplanationRequest struct {
Question string `json:"question"`
UserAnswer string `json:"user_answer"`
CorrectAnswer string `json:"correct_answer"`
Language string `json:"language"`
Level string `json:"level"`
}
// MarshalContentToJSON serializes the question content to a JSON string.
//
// Before encoding it removes the "correct_answer" and "explanation" keys from
// q.Content, because those values belong at the top level of the question and
// are not allowed in QuestionContent according to the OpenAPI schema. The
// removal mutates q.Content in place. A nil Content encodes as "null". On a
// marshaling error the returned string is empty.
func (q *Question) MarshalContentToJSON() (result0 string, err error) {
	// delete is a no-op on a nil map, so no explicit nil check is required.
	for _, topLevelKey := range [...]string{"correct_answer", "explanation"} {
		delete(q.Content, topLevelKey)
	}

	encoded, marshalErr := json.Marshal(q.Content)
	return string(encoded), marshalErr
}
// UnmarshalContentFromJSON deserializes a JSON string into the question
// content.
//
// After a successful decode it removes the "correct_answer" and "explanation"
// keys from q.Content, since those values belong at the top level of the
// question and are not allowed in QuestionContent according to the OpenAPI
// schema. A decode error is returned as-is and no cleanup is performed.
func (q *Question) UnmarshalContentFromJSON(data string) error {
	if decodeErr := json.Unmarshal([]byte(data), &q.Content); decodeErr != nil {
		return decodeErr
	}

	// delete is a no-op on a nil map (e.g. after decoding "null").
	for _, topLevelKey := range [...]string{"correct_answer", "explanation"} {
		delete(q.Content, topLevelKey)
	}
	return nil
}
// WorkerSettings represents worker configuration settings stored in database.
// Each value is one key/value pair; both the key and the value are strings.
type WorkerSettings struct {
ID int `json:"id" db:"id"`
SettingKey string `json:"setting_key" db:"setting_key"`
SettingValue string `json:"setting_value" db:"setting_value"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// WorkerStatus represents worker health and activity status.
// The nullable columns use sql.NullString / sql.NullTime; the type's
// MarshalJSON method converts them to pointers so they encode as plain
// values or null.
type WorkerStatus struct {
ID int `json:"id" db:"id"`
WorkerInstance string `json:"worker_instance" db:"worker_instance"`
IsRunning bool `json:"is_running" db:"is_running"`
IsPaused bool `json:"is_paused" db:"is_paused"`
CurrentActivity sql.NullString `json:"current_activity" db:"current_activity"`
LastHeartbeat sql.NullTime `json:"last_heartbeat" db:"last_heartbeat"`
LastRunStart sql.NullTime `json:"last_run_start" db:"last_run_start"`
// NOTE(review): LastRunEnd and LastRunFinish look like they record the same
// event under two names — confirm whether both columns are still in use.
LastRunEnd sql.NullTime `json:"last_run_end" db:"last_run_end"`
LastRunFinish sql.NullTime `json:"last_run_finish" db:"last_run_finish"`
LastRunError sql.NullString `json:"last_run_error" db:"last_run_error"`
TotalQuestionsProcessed int `json:"total_questions_processed" db:"total_questions_processed"`
TotalQuestionsGenerated int `json:"total_questions_generated" db:"total_questions_generated"`
TotalRuns int `json:"total_runs" db:"total_runs"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// MarshalJSON implements json.Marshaler for WorkerStatus.
//
// The sql.NullString and sql.NullTime fields are converted to pointers before
// encoding, so a NULL (invalid) value is emitted as JSON null and a valid one
// as its plain value. All other fields are copied through unchanged.
func (ws WorkerStatus) MarshalJSON() (result0 []byte, err error) {
	// workerStatusJSON is the wire representation of a WorkerStatus.
	type workerStatusJSON struct {
		ID                      int        `json:"id"`
		WorkerInstance          string     `json:"worker_instance"`
		IsRunning               bool       `json:"is_running"`
		IsPaused                bool       `json:"is_paused"`
		CurrentActivity         *string    `json:"current_activity"`
		LastHeartbeat           *time.Time `json:"last_heartbeat"`
		LastRunStart            *time.Time `json:"last_run_start"`
		LastRunEnd              *time.Time `json:"last_run_end"`
		LastRunFinish           *time.Time `json:"last_run_finish"`
		LastRunError            *string    `json:"last_run_error"`
		TotalQuestionsProcessed int        `json:"total_questions_processed"`
		TotalQuestionsGenerated int        `json:"total_questions_generated"`
		TotalRuns               int        `json:"total_runs"`
		CreatedAt               time.Time  `json:"created_at"`
		UpdatedAt               time.Time  `json:"updated_at"`
	}

	payload := workerStatusJSON{
		ID:                      ws.ID,
		WorkerInstance:          ws.WorkerInstance,
		IsRunning:               ws.IsRunning,
		IsPaused:                ws.IsPaused,
		CurrentActivity:         nullStringToPointer(ws.CurrentActivity),
		LastHeartbeat:           nullTimeToPointer(ws.LastHeartbeat),
		LastRunStart:            nullTimeToPointer(ws.LastRunStart),
		LastRunEnd:              nullTimeToPointer(ws.LastRunEnd),
		LastRunFinish:           nullTimeToPointer(ws.LastRunFinish),
		LastRunError:            nullStringToPointer(ws.LastRunError),
		TotalQuestionsProcessed: ws.TotalQuestionsProcessed,
		TotalQuestionsGenerated: ws.TotalQuestionsGenerated,
		TotalRuns:               ws.TotalRuns,
		CreatedAt:               ws.CreatedAt,
		UpdatedAt:               ws.UpdatedAt,
	}
	return json.Marshal(&payload)
}
package models
import (
"errors"
"strings"
"time"
)
// StoryStatus represents the status of a story.
// It is a string type whose values are the constants below.
type StoryStatus string
// Story status constants
const (
StoryStatusActive StoryStatus = "active" // StoryStatusActive represents an active story
StoryStatusArchived StoryStatus = "archived" // StoryStatusArchived represents an archived story
StoryStatusCompleted StoryStatus = "completed" // StoryStatusCompleted represents a completed story
)
// SectionLength represents the preferred length of story sections.
// It is a string type; GetSectionLengthTarget translates a value into a
// target word count for a given level.
type SectionLength string
// Section length constants
const (
SectionLengthShort SectionLength = "short" // SectionLengthShort represents a short section length
SectionLengthMedium SectionLength = "medium" // SectionLengthMedium represents a medium section length
SectionLengthLong SectionLength = "long" // SectionLengthLong represents a long section length
)
// GeneratorType represents who generated a story section.
// It is a string type whose values are the constants below.
type GeneratorType string
// Generator type constants
const (
GeneratorTypeWorker GeneratorType = "worker" // GeneratorTypeWorker represents worker-generated sections
GeneratorTypeUser GeneratorType = "user" // GeneratorTypeUser represents user-generated sections
)
// Story represents a user-created story with metadata.
// The optional descriptive fields are pointers, so an unset value encodes as
// JSON null. SectionLengthOverride is omitted from JSON when nil.
type Story struct {
ID uint `json:"id"`
UserID uint `json:"user_id"`
Title string `json:"title"`
Language string `json:"language"`
Subject *string `json:"subject"`
AuthorStyle *string `json:"author_style"`
TimePeriod *string `json:"time_period"`
Genre *string `json:"genre"`
Tone *string `json:"tone"`
CharacterNames *string `json:"character_names"`
CustomInstructions *string `json:"custom_instructions"`
SectionLengthOverride *SectionLength `json:"section_length_override,omitempty"`
Status StoryStatus `json:"status"`
AutoGenerationPaused bool `json:"auto_generation_paused"`
LastSectionGeneratedAt *time.Time `json:"last_section_generated_at"`
ExtraGenerationsToday int `json:"extra_generations_today"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// Relationships
// NOTE(review): User is a struct value, and encoding/json's omitempty never
// treats a struct as empty, so "user" is always emitted — confirm intended.
User User `json:"user,omitempty"`
Sections []StorySection `json:"sections,omitempty"`
}
// GetSectionLengthOverride returns the story's section length override as a
// plain string, or the empty string when no override is set (nil pointer).
func (s *Story) GetSectionLengthOverride() string {
	if override := s.SectionLengthOverride; override != nil {
		return string(*override)
	}
	return ""
}
// StorySection represents an individual section of a story.
// It records the section's position, text, language level and word count,
// who generated it, and two timestamps (GeneratedAt and GenerationDate).
type StorySection struct {
ID uint `json:"id"`
StoryID uint `json:"story_id"`
SectionNumber int `json:"section_number"`
Content string `json:"content"`
LanguageLevel string `json:"language_level"`
WordCount int `json:"word_count"`
GeneratedBy GeneratorType `json:"generated_by"`
GeneratedAt time.Time `json:"generated_at"`
// NOTE(review): GenerationDate presumably holds the calendar day of
// generation, as opposed to the GeneratedAt instant — confirm.
GenerationDate time.Time `json:"generation_date"`
// Relationships
// NOTE(review): Story is a struct value, so omitempty has no effect and
// "story" is always emitted by encoding/json — confirm intended.
Story Story `json:"story,omitempty"`
Questions []StorySectionQuestion `json:"questions,omitempty"`
}
// StorySectionQuestion represents a comprehension question for a story section.
// CorrectAnswerIndex is an index into Options. Explanation is a pointer and
// encodes as JSON null when unset.
type StorySectionQuestion struct {
ID uint `json:"id"`
SectionID uint `json:"section_id"`
QuestionText string `json:"question_text"`
Options []string `json:"options"`
CorrectAnswerIndex int `json:"correct_answer_index"`
Explanation *string `json:"explanation"`
CreatedAt time.Time `json:"created_at"`
// Relationships
// NOTE(review): Section is a struct value, so omitempty has no effect and
// "section" is always emitted by encoding/json — confirm intended.
Section StorySection `json:"section,omitempty"`
}
// StoryWithSections represents a story with all its sections loaded.
// It embeds Story and redeclares Sections without omitempty; the outer field
// shadows the embedded Story.Sections, so "sections" is always present in the
// JSON output.
type StoryWithSections struct {
Story
Sections []StorySection `json:"sections"`
}
// StorySectionWithQuestions represents a section with all its questions loaded.
// It embeds StorySection and redeclares Questions without omitempty; the
// outer field shadows the embedded StorySection.Questions, so "questions" is
// always present in the JSON output.
type StorySectionWithQuestions struct {
StorySection
Questions []StorySectionQuestion `json:"questions"`
}
// CreateStoryRequest represents the request to create a new story.
// Only Title is required; every other field is an optional pointer. The
// `validate` tags state the same limits that the Validate method checks by
// hand.
type CreateStoryRequest struct {
Title string `json:"title" validate:"required,min=1,max=200"`
Subject *string `json:"subject" validate:"omitempty,max=500"`
AuthorStyle *string `json:"author_style" validate:"omitempty,max=200"`
TimePeriod *string `json:"time_period" validate:"omitempty,max=200"`
Genre *string `json:"genre" validate:"omitempty,max=100"`
Tone *string `json:"tone" validate:"omitempty,max=100"`
CharacterNames *string `json:"character_names" validate:"omitempty,max=1000"`
CustomInstructions *string `json:"custom_instructions" validate:"omitempty,max=2000"`
SectionLengthOverride *SectionLength `json:"section_length_override" validate:"omitempty,oneof=short medium long"`
}
// StoryGenerationRequest represents the request for AI story generation.
// UserID and StoryID are tagged `json:"-"` and never serialized. The optional
// story attributes are omitted from JSON when nil.
type StoryGenerationRequest struct {
UserID uint `json:"-"`
StoryID uint `json:"-"`
Language string `json:"language"`
Level string `json:"level"`
Title string `json:"title"`
Subject *string `json:"subject,omitempty"`
AuthorStyle *string `json:"author_style,omitempty"`
TimePeriod *string `json:"time_period,omitempty"`
Genre *string `json:"genre,omitempty"`
Tone *string `json:"tone,omitempty"`
CharacterNames *string `json:"character_names,omitempty"`
CustomInstructions *string `json:"custom_instructions,omitempty"`
SectionLength SectionLength `json:"section_length"`
// PreviousSections is a single string; how earlier sections are joined into
// it is decided by the caller.
PreviousSections string `json:"previous_sections"`
IsFirstSection bool `json:"is_first_section"`
TargetWords int `json:"target_words"`
TargetSentences int `json:"target_sentences"`
}
// StoryQuestionsRequest represents the request for AI question generation.
// UserID and SectionID are tagged `json:"-"` and never serialized;
// QuestionCount is the number of questions requested for SectionText.
type StoryQuestionsRequest struct {
UserID uint `json:"-"`
SectionID uint `json:"-"`
Language string `json:"language"`
Level string `json:"level"`
SectionText string `json:"section_text"`
QuestionCount int `json:"question_count"`
}
// StorySectionQuestionData represents the structure returned by AI for questions.
// Its fields mirror the content fields of StorySectionQuestion, without the
// ID, SectionID, CreatedAt and Section fields.
type StorySectionQuestionData struct {
QuestionText string `json:"question_text"`
Options []string `json:"options"`
CorrectAnswerIndex int `json:"correct_answer_index"`
Explanation *string `json:"explanation"`
}
// Validate validates the CreateStoryRequest.
//
// It checks, in order: that Title is non-empty and at most 200 characters;
// that each optional text field, when set, stays within its limit (Subject
// 500, AuthorStyle 200, TimePeriod 200, Genre 100, Tone 100, CharacterNames
// 1000, CustomInstructions 2000); and that SectionLengthOverride, when set,
// is one of short, medium or long. The first violation found is returned as
// an error; nil means the request is valid.
//
// Lengths are measured in characters (runes), not bytes, so text in non-ASCII
// scripts is not rejected early. This matches the wording of the error
// messages and the `max` limits declared in the struct's validate tags.
func (r *CreateStoryRequest) Validate() error {
	// len([]rune(s)) counts characters; plain len(s) would count UTF-8 bytes.
	exceeds := func(s string, limit int) bool {
		return len([]rune(s)) > limit
	}

	if r.Title == "" {
		return errors.New("title is required")
	}
	if exceeds(r.Title, 200) {
		return errors.New("title must be 200 characters or less")
	}

	// Optional fields, listed in the order their errors take precedence.
	optional := []struct {
		value   *string
		limit   int
		message string
	}{
		{r.Subject, 500, "subject must be 500 characters or less"},
		{r.AuthorStyle, 200, "author style must be 200 characters or less"},
		{r.TimePeriod, 200, "time period must be 200 characters or less"},
		{r.Genre, 100, "genre must be 100 characters or less"},
		{r.Tone, 100, "tone must be 100 characters or less"},
		{r.CharacterNames, 1000, "character names must be 1000 characters or less"},
		{r.CustomInstructions, 2000, "custom instructions must be 2000 characters or less"},
	}
	for _, field := range optional {
		if field.value != nil && exceeds(*field.value, field.limit) {
			return errors.New(field.message)
		}
	}

	if r.SectionLengthOverride != nil {
		switch *r.SectionLengthOverride {
		case SectionLengthShort, SectionLengthMedium, SectionLengthLong:
			// Valid
		default:
			return errors.New("section length override must be one of: short, medium, long")
		}
	}
	return nil
}
// SanitizeInput sanitizes user input for safe use in AI prompts.
//
// It performs basic sanitization only: surrounding whitespace is trimmed and
// every ASCII control byte below 0x20 is removed, except tab (9), line feed
// (10) and carriage return (13). In a production system, you might want more
// sophisticated sanitization.
//
// The input is processed in a single pass. Bytes below 0x20 never occur
// inside a multi-byte UTF-8 sequence, so working on bytes leaves non-ASCII
// text intact. Whitespace that only becomes leading or trailing once control
// bytes are removed (e.g. "text \x00") is trimmed as well.
func SanitizeInput(input string) string {
	trimmed := strings.TrimSpace(input)

	isStripped := func(b byte) bool {
		return b < 32 && b != '\t' && b != '\n' && b != '\r'
	}

	// Locate the first byte to drop; most inputs have none.
	first := -1
	for i := 0; i < len(trimmed); i++ {
		if isStripped(trimmed[i]) {
			first = i
			break
		}
	}
	if first < 0 {
		return trimmed
	}

	var cleaned strings.Builder
	cleaned.Grow(len(trimmed) - 1)
	cleaned.WriteString(trimmed[:first])
	for i := first + 1; i < len(trimmed); i++ {
		if c := trimmed[i]; !isStripped(c) {
			cleaned.WriteByte(c)
		}
	}

	// Stripping can expose whitespace at either end; trim once more.
	return strings.TrimSpace(cleaned.String())
}
// UserAIConfig holds per-user AI configuration.
// The struct has no serialization tags. It contains the raw APIKey, so take
// care not to log or marshal whole values.
type UserAIConfig struct {
Provider string
Model string
APIKey string
Username string // For logging purposes
}
// StoryGenerationEligibilityResponse represents the result of checking if a story section can be generated.
// Reason and Story are optional and omitted from JSON when empty or nil.
type StoryGenerationEligibilityResponse struct {
CanGenerate bool `json:"can_generate"`
Reason string `json:"reason,omitempty"`
Story *Story `json:"story,omitempty"` // Include story data when needed for additional checks
}
// GetSectionLengthTarget returns the target word count for a story section.
//
// level is a CEFR level ("A1" through "C2"). Each level corresponds to a
// generic proficiency tier (beginner, elementary, intermediate,
// upper_intermediate, advanced, proficient) with its own short/medium/long
// word targets. Any unrecognized level, including the empty string, uses the
// intermediate (B1) targets.
//
// lengthPref selects short, medium or long. A nil pointer or an unrecognized
// value yields the medium target.
//
// The targets are chosen with switch statements, so no lookup tables are
// allocated per call.
func GetSectionLengthTarget(level string, lengthPref *SectionLength) int {
	// Word-count targets for the requested level.
	var short, medium, long int
	switch level {
	case "A1": // beginner
		short, medium, long = 50, 80, 120
	case "A2": // elementary
		short, medium, long = 80, 120, 180
	case "B2": // upper_intermediate
		short, medium, long = 250, 350, 450
	case "C1": // advanced
		short, medium, long = 350, 500, 650
	case "C2": // proficient
		short, medium, long = 500, 700, 900
	default: // "B1" (intermediate) and every unrecognized level
		short, medium, long = 150, 220, 300
	}

	if lengthPref != nil {
		switch *lengthPref {
		case SectionLengthShort:
			return short
		case SectionLengthLong:
			return long
		}
	}
	// Default to medium length
	return medium
}
package models
import (
"encoding/json"
"time"
)
// WordSourceType represents the type of source for the word of the day.
// It is a string type whose values are the constants below.
type WordSourceType string
// Word-of-the-day source types.
const (
// WordSourceVocabularyQuestion represents a word from a vocabulary question
WordSourceVocabularyQuestion WordSourceType = "vocabulary_question"
// WordSourceSnippet represents a word from a user snippet
WordSourceSnippet WordSourceType = "snippet"
)
// WordOfTheDay represents a daily word assignment for a user.
// SourceType states which kind of record SourceID refers to (a vocabulary
// question or a snippet).
type WordOfTheDay struct {
ID int `json:"id" db:"id"`
UserID int `json:"user_id" db:"user_id"`
AssignmentDate time.Time `json:"assignment_date" db:"assignment_date"`
SourceType WordSourceType `json:"source_type" db:"source_type"`
SourceID int `json:"source_id" db:"source_id"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
}
// WordOfTheDayWithContent represents a word of the day with full content details.
// It embeds WordOfTheDay and adds the source record; per the field comments,
// one of Question or Snippet is populated depending on SourceType, and a nil
// pointer is omitted from JSON.
type WordOfTheDayWithContent struct {
WordOfTheDay
// Question is populated when SourceType is WordSourceVocabularyQuestion
Question *Question `json:"question,omitempty"`
// Snippet is populated when SourceType is WordSourceSnippet
Snippet *Snippet `json:"snippet,omitempty"`
}
// WordOfTheDayDisplay represents the simplified display format for word of the day
// This is used for API responses and contains the essential information.
// Date is a time.Time in memory; the type's MarshalJSON method emits it as a
// YYYY-MM-DD string. Level, Context, Explanation and TopicCategory are
// omitted from JSON when empty.
type WordOfTheDayDisplay struct {
Date time.Time `json:"date"`
Word string `json:"word"`
Translation string `json:"translation"`
Sentence string `json:"sentence"`
SourceType WordSourceType `json:"source_type"`
SourceID int `json:"source_id"`
Language string `json:"language"`
Level string `json:"level,omitempty"`
Context string `json:"context,omitempty"`
Explanation string `json:"explanation,omitempty"`
TopicCategory string `json:"topic_category,omitempty"`
}
// MarshalJSON implements json.Marshaler for WordOfTheDayDisplay.
//
// The Date field is converted to UTC and rendered as a YYYY-MM-DD string, in
// line with the OpenAPI "date" format (not "date-time"). All other fields are
// copied through unchanged, keeping the same keys and omitempty behavior as
// the struct's own tags.
func (w WordOfTheDayDisplay) MarshalJSON() ([]byte, error) {
	// displayJSON is the wire representation, with Date as a string.
	type displayJSON struct {
		Date          string         `json:"date"`
		Word          string         `json:"word"`
		Translation   string         `json:"translation"`
		Sentence      string         `json:"sentence"`
		SourceType    WordSourceType `json:"source_type"`
		SourceID      int            `json:"source_id"`
		Language      string         `json:"language"`
		Level         string         `json:"level,omitempty"`
		Context       string         `json:"context,omitempty"`
		Explanation   string         `json:"explanation,omitempty"`
		TopicCategory string         `json:"topic_category,omitempty"`
	}

	day := w.Date.UTC().Format("2006-01-02")
	payload := displayJSON{
		Date:          day,
		Word:          w.Word,
		Translation:   w.Translation,
		Sentence:      w.Sentence,
		SourceType:    w.SourceType,
		SourceID:      w.SourceID,
		Language:      w.Language,
		Level:         w.Level,
		Context:       w.Context,
		Explanation:   w.Explanation,
		TopicCategory: w.TopicCategory,
	}
	return json.Marshal(&payload)
}
package observability
import (
"context"
"fmt"
"quizapp/internal/models"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// globalTracer is the package-wide tracer. It is assigned by InitGlobalTracer
// or lazily by GetGlobalTracer, without any synchronization.
var globalTracer trace.Tracer
// InitGlobalTracer initializes the global tracer for the application.
// It obtains a tracer named "quiz-app" from the otel package and stores it in
// globalTracer, replacing any previous value.
func InitGlobalTracer() {
globalTracer = otel.Tracer("quiz-app")
}
// GetGlobalTracer returns the global tracer instance for the application.
// If InitGlobalTracer has not run yet, the "quiz-app" tracer is created on
// first use and cached in globalTracer.
func GetGlobalTracer() trace.Tracer {
	if globalTracer != nil {
		return globalTracer
	}
	// Fallback to default tracer if not initialized
	globalTracer = otel.Tracer("quiz-app")
	return globalTracer
}
// TraceFunction starts a new span named "<serviceName>.<functionName>" on the
// global tracer, attaching the given attributes. It returns the context
// produced by the tracer's Start call together with the span; the caller is
// responsible for ending the span.
func TraceFunction(ctx context.Context, serviceName, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	name := fmt.Sprintf("%s.%s", serviceName, functionName)
	startOptions := trace.WithAttributes(attributes...)
	return GetGlobalTracer().Start(ctx, name, startOptions)
}
// TraceFunctionWithErrorHandling starts a new span and automatically adds error attributes if the function panics or returns an error.
//
// It opens a span via TraceFunction, runs fn, and ends the span exactly once
// on every path:
//   - fn returns nil: the span is ended and nil is returned.
//   - fn returns an error: "error" and "error.message" attributes are set, the
//     span is ended and the error is returned unchanged.
//   - fn panics: "error", "error.type"="panic" and "error.message" attributes
//     are set, the span is ended and the panic is re-raised with the same value.
//
// The context returned by TraceFunction is discarded, and fn takes no
// arguments, so fn cannot receive the span through a context from here.
func TraceFunctionWithErrorHandling(ctx context.Context, serviceName, functionName string, fn func() error, attributes ...attribute.KeyValue) error {
_, span := TraceFunction(ctx, serviceName, functionName, attributes...)
// The deferred handler only acts when fn panicked; on the normal path
// recover() returns nil and the span has already been ended below.
defer func() {
// Here err is the recovered panic value (any type), not an error.
if err := recover(); err != nil {
span.SetAttributes(
attribute.Bool("error", true),
attribute.String("error.type", "panic"),
attribute.String("error.message", fmt.Sprintf("%v", err)),
)
// End the span before propagating, since the code below will not run.
span.End()
panic(err) // re-panic
}
}()
err := fn()
if err != nil {
span.SetAttributes(
attribute.Bool("error", true),
attribute.String("error.message", err.Error()),
)
}
span.End()
return err
}
// TraceSnippetFunction starts a new span for a snippet service function.
// It delegates to TraceFunction with the service name "snippet", forwarding
// ctx, functionName and attributes unchanged.
func TraceSnippetFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "snippet"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceTranslationFunction starts a new span for a translation service
// function. It delegates to TraceFunction with the service name
// "translation", forwarding ctx, functionName and attributes unchanged.
func TraceTranslationFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "translation"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceAIFunction starts a new span for an AI service function.
// It delegates to TraceFunction with the service name "ai", forwarding ctx,
// functionName and attributes unchanged.
func TraceAIFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "ai"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceUserFunction starts a new span for a user service function.
// It delegates to TraceFunction with the service name "user", forwarding ctx,
// functionName and attributes unchanged.
func TraceUserFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "user"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceQuestionFunction starts a new span for a question service function.
// It delegates to TraceFunction with the service name "question", forwarding
// ctx, functionName and attributes unchanged.
func TraceQuestionFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "question"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceWorkerFunction starts a new span for a worker service function.
// It delegates to TraceFunction with the service name "worker", forwarding
// ctx, functionName and attributes unchanged.
func TraceWorkerFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "worker"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceLearningFunction starts a new span for a learning service function.
// It delegates to TraceFunction with the service name "learning", forwarding
// ctx, functionName and attributes unchanged.
func TraceLearningFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "learning"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceHandlerFunction starts a new span for a handler function.
// It delegates to TraceFunction with the service name "handler", forwarding
// ctx, functionName and attributes unchanged.
func TraceHandlerFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "handler"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceVarietyFunction starts a new span for a variety service function.
// It delegates to TraceFunction with the service name "variety", forwarding
// ctx, functionName and attributes unchanged.
func TraceVarietyFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "variety"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceOAuthFunction starts a new span for an OAuth service function.
// It delegates to TraceFunction with the service name "oauth", forwarding
// ctx, functionName and attributes unchanged.
func TraceOAuthFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "oauth"
	return TraceFunction(ctx, service, functionName, attributes...)
}
// TraceUsageStatsFunction opens a span for a usage stats service function. It
// is a thin wrapper that forwards to TraceFunction with the service name
// "usage_stats" and returns the context and span that call produces.
func TraceUsageStatsFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "usage_stats"
	spanCtx, span := TraceFunction(ctx, service, functionName, attributes...)
	return spanCtx, span
}
// TraceCleanupFunction opens a span for a cleanup service function. It is a
// thin wrapper that forwards to TraceFunction with the service name "cleanup"
// and returns the context and span that call produces.
func TraceCleanupFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "cleanup"
	spanCtx, span := TraceFunction(ctx, service, functionName, attributes...)
	return spanCtx, span
}
// TraceDatabaseFunction opens a span for a database function. It is a thin
// wrapper that forwards to TraceFunction with the service name "database" and
// returns the context and span that call produces.
func TraceDatabaseFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	const service = "database"
	spanCtx, span := TraceFunction(ctx, service, functionName, attributes...)
	return spanCtx, span
}
// AttributeQuestion builds the "question.id" tracing attribute from a
// question's ID, formatted as a decimal string. q must be non-nil: the
// function dereferences it unconditionally.
//
// Note that the value is emitted as a string here, whereas
// AttributeQuestionID emits an integer under the same key.
func AttributeQuestion(q *models.Question) attribute.KeyValue {
	id := fmt.Sprintf("%d", q.ID)
	return attribute.Key("question.id").String(id)
}
// AttributeQuestionID builds the integer "question.id" tracing attribute.
func AttributeQuestionID(id int) attribute.KeyValue {
	return attribute.Key("question.id").Int(id)
}
// AttributeUserID builds the integer "user.id" tracing attribute.
func AttributeUserID(id int) attribute.KeyValue {
	return attribute.Key("user.id").Int(id)
}
// AttributeSnippetID builds the integer "snippet.id" tracing attribute.
func AttributeSnippetID(id int) attribute.KeyValue {
	return attribute.Key("snippet.id").Int(id)
}
// AttributeLanguage builds the string "language" tracing attribute.
func AttributeLanguage(lang string) attribute.KeyValue {
	return attribute.Key("language").String(lang)
}
// AttributeGenerationType builds the string "generation_type" tracing
// attribute from the generator type's underlying string value.
func AttributeGenerationType(generationType models.GeneratorType) attribute.KeyValue {
	value := string(generationType)
	return attribute.Key("generation_type").String(value)
}
// AttributeLevel builds the string "level" tracing attribute.
func AttributeLevel(level string) attribute.KeyValue {
	return attribute.Key("level").String(level)
}
// AttributeQuestionType builds the string "question.type" tracing attribute.
// qType may be any value; it is rendered with fmt's default (%v) formatting.
func AttributeQuestionType(qType interface{}) attribute.KeyValue {
	rendered := fmt.Sprintf("%v", qType)
	return attribute.Key("question.type").String(rendered)
}
// AttributeLimit builds the integer "limit" tracing attribute.
func AttributeLimit(limit int) attribute.KeyValue {
	return attribute.Key("limit").Int(limit)
}
// AttributePage builds the integer "page" tracing attribute.
func AttributePage(page int) attribute.KeyValue {
	return attribute.Key("page").Int(page)
}
// AttributePageSize builds the integer "page_size" tracing attribute.
func AttributePageSize(size int) attribute.KeyValue {
	return attribute.Key("page_size").Int(size)
}
// AttributeSearch builds the string "search" tracing attribute.
func AttributeSearch(search string) attribute.KeyValue {
	return attribute.Key("search").String(search)
}
// AttributeTypeFilter builds the string "type_filter" tracing attribute.
func AttributeTypeFilter(typeFilter string) attribute.KeyValue {
	return attribute.Key("type_filter").String(typeFilter)
}
// AttributeStatusFilter builds the string "status_filter" tracing attribute.
func AttributeStatusFilter(statusFilter string) attribute.KeyValue {
	return attribute.Key("status_filter").String(statusFilter)
}
// Package observability provides OpenTelemetry tracing, metrics, and structured logging
// with trace correlation for the quiz application.
package observability
import (
"context"
"os"
"quizapp/internal/config"
"go.opentelemetry.io/contrib/bridges/otelzap"
"go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc"
"go.opentelemetry.io/otel/sdk/log"
"go.opentelemetry.io/otel/sdk/resource"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// Logger wraps a *zap.Logger. The zap logger is embedded, so its methods are
// promoted onto Logger except where Logger declares a method of the same name:
// Debug, Info, Warn, Error and Sync below take precedence, and the four logging
// methods accept a context plus map-based fields instead of zap fields.
type Logger struct {
*zap.Logger
}
// NewLogger builds a Logger at the default Info level. It is shorthand for
// NewLoggerWithLevel(cfg, zap.InfoLevel); see that function for how cfg is
// interpreted.
func NewLogger(cfg *config.OpenTelemetryConfig) *Logger {
	const defaultLevel = zap.InfoLevel
	return NewLoggerWithLevel(cfg, defaultLevel)
}
// NewLoggerWithLevel builds a Logger that writes to stdout at the given level
// and, when cfg.Endpoint is set, additionally exports records over OTLP/gRPC.
//
// Behavior:
//   - cfg == nil or cfg.EnableLogging == false: a no-op logger is returned.
//   - ENV=development selects zap's development config, otherwise the
//     production config with an ISO8601 "timestamp" key is used.
//   - If the zap config fails to build, zap.NewExample() is used as a fallback.
//   - Any failure while setting up OTLP export is logged and the stdout-only
//     logger is returned; this function never returns nil.
//
// The OTLP core is attached with zap.WrapCore so that the options installed by
// zapConfig.Build() (caller annotation, stack traces, error output,
// development mode) are kept. Rebuilding the logger with zap.New(core) would
// silently drop them as soon as OTLP export is enabled.
func NewLoggerWithLevel(cfg *config.OpenTelemetryConfig, level zapcore.Level) *Logger {
	// If logging is disabled, return a no-op logger
	if cfg == nil || !cfg.EnableLogging {
		return &Logger{Logger: zap.NewNop()}
	}

	// Base stdout logger: production settings unless ENV=development.
	zapConfig := zap.NewProductionConfig()
	zapConfig.EncoderConfig.TimeKey = "timestamp"
	zapConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
	zapConfig.EncoderConfig.StacktraceKey = "stacktrace"
	if os.Getenv("ENV") == "development" {
		zapConfig = zap.NewDevelopmentConfig()
	}
	zapConfig.Level = zap.NewAtomicLevelAt(level)

	zapLogger, err := zapConfig.Build()
	if err != nil {
		// Fallback to a basic logger if config fails
		zapLogger = zap.NewExample()
	}

	// Without an endpoint there is nothing to export to; stay stdout-only.
	if cfg.Endpoint == "" {
		zapLogger.Info("OTLP logging not enabled", zap.Bool("enable_logging", cfg.EnableLogging), zap.String("endpoint", cfg.Endpoint))
		return &Logger{Logger: zapLogger}
	}

	endpoint := cfg.Endpoint
	zapLogger.Info("Setting up OTLP logging", zap.String("endpoint", endpoint), zap.String("protocol", cfg.Protocol))

	// Set up resource attributes
	res, err := resource.New(context.Background(),
		resource.WithAttributes(
			semconv.ServiceName(cfg.ServiceName),
			semconv.ServiceVersion(cfg.ServiceVersion),
		),
	)
	if err != nil {
		// Log the error but continue with stdout logging
		zapLogger.Error("Failed to create otel resource", zap.Error(err))
		return &Logger{Logger: zapLogger}
	}

	// NOTE(review): the log exporter always uses gRPC without TLS; it ignores
	// cfg.Protocol, cfg.Insecure and cfg.Headers, unlike the trace and metric
	// exporters. Left unchanged here to avoid altering deployed behavior.
	exporter, err := otlploggrpc.New(context.Background(),
		otlploggrpc.WithEndpoint(endpoint),
		otlploggrpc.WithInsecure(),
	)
	if err != nil {
		// Log the error but continue with stdout logging
		zapLogger.Error("Failed to create OTLP exporter", zap.Error(err), zap.String("endpoint", endpoint))
		return &Logger{Logger: zapLogger}
	}
	zapLogger.Info("Successfully created OTLP exporter", zap.String("endpoint", endpoint))

	// Batch records and attach the service resource to everything exported.
	provider := log.NewLoggerProvider(
		log.WithProcessor(log.NewBatchProcessor(exporter)),
		log.WithResource(res),
	)
	otelCore := otelzap.NewCore("quizapp", otelzap.WithLoggerProvider(provider))

	// Tee the existing stdout core with the OTLP core while keeping every
	// option already configured on zapLogger.
	zapLogger = zapLogger.WithOptions(zap.WrapCore(func(stdoutCore zapcore.Core) zapcore.Core {
		return zapcore.NewTee(stdoutCore, otelCore)
	}))
	zapLogger.Info("OTLP logging successfully configured", zap.String("endpoint", endpoint))

	return &Logger{Logger: zapLogger}
}
// Debug emits msg at debug level. ctx and any field maps are handed to
// logWithContext unchanged.
func (l *Logger) Debug(ctx context.Context, msg string, fields ...map[string]interface{}) {
	const level = zap.DebugLevel
	l.logWithContext(ctx, level, msg, fields...)
}
// Info emits msg at info level. ctx and any field maps are handed to
// logWithContext unchanged.
func (l *Logger) Info(ctx context.Context, msg string, fields ...map[string]interface{}) {
	const level = zap.InfoLevel
	l.logWithContext(ctx, level, msg, fields...)
}
// Warn emits msg at warning level. ctx and any field maps are handed to
// logWithContext unchanged.
func (l *Logger) Warn(ctx context.Context, msg string, fields ...map[string]interface{}) {
	const level = zap.WarnLevel
	l.logWithContext(ctx, level, msg, fields...)
}
// Error emits msg at error level. When err is non-nil its message is added to
// the logged fields under the "error" key, overriding any caller-supplied
// "error" field.
//
// The caller's field maps are never modified: mergeFields may hand back the
// caller's own map when exactly one map is passed, so the fields are copied
// into a fresh map before "error" is written.
func (l *Logger) Error(ctx context.Context, msg string, err error, fields ...map[string]interface{}) {
	merged := l.mergeFields(fields...)
	if err == nil {
		// Nothing to add, so no write happens and no copy is needed.
		l.logWithContext(ctx, zap.ErrorLevel, msg, merged)
		return
	}

	allFields := make(map[string]interface{}, len(merged)+1)
	for k, v := range merged {
		allFields[k] = v
	}
	allFields["error"] = err.Error()
	l.logWithContext(ctx, zap.ErrorLevel, msg, allFields)
}
// logWithContext flattens the given field maps into zap fields and writes msg
// through the embedded zap logger at the requested level.
//
// The context parameter is currently unused (it is bound to the blank
// identifier), so no trace or span information is read from it here.
// Debug, Warn and Error map to the matching zap methods; Info and every other
// level are written with Info.
func (l *Logger) logWithContext(_ context.Context, level zapcore.Level, msg string, fields ...map[string]interface{}) {
	merged := l.mergeFields(fields...)

	// One zap.Any field per map entry; map iteration order is unspecified.
	zapFields := make([]zap.Field, 0, len(merged))
	for key, value := range merged {
		zapFields = append(zapFields, zap.Any(key, value))
	}

	switch level {
	case zap.ErrorLevel:
		l.Logger.Error(msg, zapFields...)
	case zap.WarnLevel:
		l.Logger.Warn(msg, zapFields...)
	case zap.DebugLevel:
		l.Logger.Debug(msg, zapFields...)
	default:
		// zap.InfoLevel and any level not handled above.
		l.Logger.Info(msg, zapFields...)
	}
}
// mergeFields collapses zero or more field maps into one map and never
// returns nil.
//
//   - No maps, or a single nil map: a new empty map is returned.
//   - Exactly one non-nil map: that same map is returned, not a copy, so
//     writes to the result are visible to the caller who supplied it.
//   - Two or more maps: a new map is returned; nil maps are skipped and, for
//     duplicate keys, the value from the later map wins.
func (l *Logger) mergeFields(fields ...map[string]interface{}) map[string]interface{} {
	switch len(fields) {
	case 0:
		return map[string]interface{}{}
	case 1:
		if only := fields[0]; only != nil {
			return only
		}
		return map[string]interface{}{}
	}

	combined := make(map[string]interface{})
	for i := range fields {
		// Ranging over a nil map performs no iterations, which skips it.
		for key, value := range fields[i] {
			combined[key] = value
		}
	}
	return combined
}
// Sync calls Sync on the embedded zap logger and returns its error.
func (l *Logger) Sync() error {
	syncErr := l.Logger.Sync()
	return syncErr
}
package observability
import (
"context"
"quizapp/internal/config"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
)
// InitMetrics builds an OpenTelemetry MeterProvider that periodically exports
// metrics over OTLP.
//
// cfg.Protocol selects the exporter: "grpc" or "http"; any other value is an
// error. cfg.Endpoint is passed to the exporter unchanged, and cfg.Headers is
// sent with every export. TLS is disabled only when cfg.Insecure is true.
//
// On failure the returned provider is nil and the error wraps
// contextutils.ErrInternalError. The provider is not registered globally here.
func InitMetrics(cfg *config.OpenTelemetryConfig) (result0 *metric.MeterProvider, err error) {
	if cfg == nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "otel config is nil")
	}
	ctx := context.Background()

	// Set up resource attributes
	res, err := resource.New(ctx,
		resource.WithAttributes(
			semconv.ServiceName(cfg.ServiceName),
			semconv.ServiceVersion(cfg.ServiceVersion),
		),
	)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otel resource: %w", err)
	}

	// Set up exporter. Options are appended conditionally: a nil Option must
	// never be passed to the exporter constructors.
	var exporter metric.Exporter
	switch cfg.Protocol {
	case "grpc":
		grpcOpts := []otlpmetricgrpc.Option{
			otlpmetricgrpc.WithEndpoint(cfg.Endpoint),
			otlpmetricgrpc.WithHeaders(cfg.Headers),
		}
		if cfg.Insecure {
			grpcOpts = append(grpcOpts, otlpmetricgrpc.WithInsecure())
		}
		exp, expErr := otlpmetricgrpc.New(ctx, grpcOpts...)
		if expErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otlp grpc metric exporter: %w", expErr)
		}
		exporter = exp
	case "http":
		httpOpts := []otlpmetrichttp.Option{
			otlpmetrichttp.WithEndpoint(cfg.Endpoint),
			otlpmetrichttp.WithHeaders(cfg.Headers),
		}
		if cfg.Insecure {
			httpOpts = append(httpOpts, otlpmetrichttp.WithInsecure())
		}
		exp, expErr := otlpmetrichttp.New(ctx, httpOpts...)
		if expErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otlp http metric exporter: %w", expErr)
		}
		exporter = exp
	default:
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "unsupported otel protocol: %s", cfg.Protocol)
	}

	// Set up meter provider
	mp := metric.NewMeterProvider(
		metric.WithReader(metric.NewPeriodicReader(exporter)),
		metric.WithResource(res),
	)
	return mp, nil
}
package observability
import (
"errors"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
contextutils "quizapp/internal/utils"
)
// GinMiddleware returns the stock otelgin middleware for the given service
// name, with no additional behavior layered on top.
func GinMiddleware(serviceName string) gin.HandlerFunc {
	handler := otelgin.Middleware(serviceName)
	return handler
}
// GinMiddlewareWithErrorHandling returns a Gin handler that runs the otelgin
// middleware and then, for responses with status >= 400, annotates the span
// found on the request context with error details:
//
//   - a recorded error (with stack trace) and an Error span status, using the
//     AppError message when one is present in c.Errors, otherwise the last
//     Gin error message, otherwise "server error" / "client error";
//   - http.status_code, http.method, http.path, error.handler, error.severity;
//   - error.user_id when a session holding an int "user_id" is available;
//   - error.request_size when the request declares a positive content length;
//   - error.code and error.retryable for the first AppError in c.Errors;
//   - error.server_error for status >= 500.
//
// NOTE(review): the span is read from the request context after the otelgin
// handler has returned. Depending on the otelgin version, that handler may
// already have ended its span and restored the original request context by
// then, in which case these attributes land on a different or non-recording
// span. Confirm against the otelgin version in use.
func GinMiddlewareWithErrorHandling(serviceName string) gin.HandlerFunc {
	// Build the otelgin handler once rather than on every request.
	otelMiddleware := otelgin.Middleware(serviceName)

	return func(c *gin.Context) {
		otelMiddleware(c)
		// Kept from the original flow: run any handlers not yet executed.
		c.Next()

		span := trace.SpanFromContext(c.Request.Context())
		statusCode := c.Writer.Status()
		if span == nil || statusCode < 400 {
			return
		}

		// Defaults derived from the status code alone.
		severity := determineErrorSeverity(statusCode, c.Errors)
		errorMsg := "client error"
		if statusCode >= 500 {
			errorMsg = "server error"
		}

		// Prefer the first AppError; until one is found, each plain Gin error
		// overwrites the message.
		var appErr *contextutils.AppError
		for _, ginErr := range c.Errors {
			if candidate, ok := ginErr.Err.(*contextutils.AppError); ok {
				appErr = candidate
				errorMsg = candidate.Message
				severity = string(candidate.Severity)
				break
			}
			errorMsg = ginErr.Error()
		}

		// Record the error with stack trace
		span.RecordError(errors.New(errorMsg), trace.WithStackTrace(true))
		span.SetStatus(codes.Error, errorMsg)

		span.SetAttributes(
			attribute.Int("http.status_code", statusCode),
			attribute.String("http.method", c.Request.Method),
			attribute.String("http.path", c.Request.URL.Path),
			attribute.String("error.handler", c.HandlerName()),
			attribute.String("error.severity", severity),
		)

		// sessions.Default panics when no session is stored on the context
		// (e.g. the request was aborted before the sessions middleware ran),
		// so look the session up without panicking.
		if raw, exists := c.Get(sessions.DefaultKey); exists {
			if session, ok := raw.(sessions.Session); ok {
				if userID, ok := session.Get("user_id").(int); ok {
					span.SetAttributes(attribute.Int("error.user_id", userID))
				}
			}
		}

		// Add request body size for debugging
		if c.Request.ContentLength > 0 {
			span.SetAttributes(attribute.Int64("error.request_size", c.Request.ContentLength))
		}

		if appErr != nil {
			span.SetAttributes(
				attribute.String("error.code", string(appErr.Code)),
				attribute.Bool("error.retryable", contextutils.IsRetryable(appErr)),
			)
		}

		if statusCode >= 500 {
			span.SetAttributes(attribute.Bool("error.server_error", true))
		}
	}
}
// determineErrorSeverity picks a severity string for a failed request.
//
// The first entry in errors whose Err is a *contextutils.AppError decides the
// result via its Severity. If there is none, the status code decides:
// SeverityError for >= 500, SeverityWarn for >= 400, SeverityInfo otherwise.
func determineErrorSeverity(statusCode int, errors []*gin.Error) string {
	for i := range errors {
		appErr, isAppErr := errors[i].Err.(*contextutils.AppError)
		if isAppErr {
			return string(appErr.Severity)
		}
	}

	// No AppError present: fall back to the HTTP status class.
	if statusCode >= 500 {
		return string(contextutils.SeverityError)
	}
	if statusCode >= 400 {
		return string(contextutils.SeverityWarn)
	}
	return string(contextutils.SeverityInfo)
}
package observability
import (
"quizapp/internal/config"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/trace"
)
// SetupObservability wires up tracing, metrics and logging according to cfg.
//
// A non-empty serviceName overwrites cfg.ServiceName (cfg is modified in
// place). Tracing and metrics are initialized only when their Enable* flag is
// set; a disabled component yields a nil provider. A Logger is always
// returned on success; when logging is disabled it is built from a config
// with EnableLogging set to false.
//
// If tracing fails, all results are nil. If metrics fail, the already-created
// tracer provider is returned alongside the error.
func SetupObservability(cfg *config.OpenTelemetryConfig, serviceName string) (result0 *trace.TracerProvider, result1 *metric.MeterProvider, result2 *Logger, err error) {
	if serviceName != "" {
		cfg.ServiceName = serviceName
	}

	var (
		tracerProvider *trace.TracerProvider
		meterProvider  *metric.MeterProvider
	)

	if cfg.EnableTracing {
		if tracerProvider, err = InitTracing(cfg); err != nil {
			return nil, nil, nil, err
		}
		// Initialize the global tracer
		InitGlobalTracer()
	}

	if cfg.EnableMetrics {
		if meterProvider, err = InitMetrics(cfg); err != nil {
			return tracerProvider, nil, nil, err
		}
	}

	// With logging disabled, hand NewLogger a config that says so explicitly.
	loggerCfg := cfg
	if !cfg.EnableLogging {
		loggerCfg = &config.OpenTelemetryConfig{EnableLogging: false}
	}
	return tracerProvider, meterProvider, NewLogger(loggerCfg), nil
}
package observability
import (
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// FinishSpan ends span, first recording the error that errPtr points to, if
// any. It is meant to be deferred together with a named error return:
//
//	defer observability.FinishSpan(span, &err)
//
// A nil span makes the call a no-op. When errPtr or the error it points to is
// nil, the span is simply ended. Otherwise the error is recorded with a stack
// trace and the span status is set to Error before the span ends.
func FinishSpan(span trace.Span, errPtr *error) {
	if span == nil {
		return
	}
	if errPtr == nil || *errPtr == nil {
		span.End()
		return
	}

	failure := *errPtr
	span.RecordError(failure, trace.WithStackTrace(true))
	span.SetStatus(codes.Error, failure.Error())
	span.End()
}
package observability
import (
"context"
"quizapp/internal/config"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
"go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
)
// InitTracing builds an OpenTelemetry TracerProvider that batches spans to an
// OTLP exporter, and installs it globally together with a W3C TraceContext +
// Baggage propagator.
//
// cfg.Protocol selects the exporter: "grpc" or "http"; any other value is an
// error. cfg.Endpoint is passed to the exporter unchanged and cfg.Headers is
// sent with every export. Sampling is parent-based with a trace-ID ratio of
// cfg.SamplingRate for root spans.
//
// On failure the returned provider is nil and the error wraps
// contextutils.ErrInternalError; global state is left untouched in that case.
func InitTracing(cfg *config.OpenTelemetryConfig) (result0 *trace.TracerProvider, err error) {
	if cfg == nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "otel config is nil")
	}
	ctx := context.Background()

	// Set up resource attributes
	res, err := resource.New(ctx,
		resource.WithAttributes(
			semconv.ServiceName(cfg.ServiceName),
			semconv.ServiceVersion(cfg.ServiceVersion),
		),
	)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otel resource: %w", err)
	}

	// Set up exporter. Options are appended conditionally: a nil Option must
	// never be passed to the exporter constructors.
	var exporter trace.SpanExporter
	switch cfg.Protocol {
	case "grpc":
		grpcOpts := []otlptracegrpc.Option{
			otlptracegrpc.WithEndpoint(cfg.Endpoint),
			otlptracegrpc.WithHeaders(cfg.Headers),
		}
		if cfg.Insecure {
			grpcOpts = append(grpcOpts, otlptracegrpc.WithInsecure())
		}
		exp, expErr := otlptracegrpc.New(ctx, grpcOpts...)
		if expErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otlp grpc exporter: %w", expErr)
		}
		exporter = exp
	case "http":
		// NOTE(review): unlike the gRPC branch, the HTTP exporter disables TLS
		// unconditionally and ignores cfg.Insecure. Left as-is to avoid
		// changing deployed behavior; confirm whether this is intended.
		exp, expErr := otlptracehttp.New(ctx,
			otlptracehttp.WithEndpoint(cfg.Endpoint),
			otlptracehttp.WithInsecure(),
			otlptracehttp.WithHeaders(cfg.Headers),
		)
		if expErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otlp http exporter: %w", expErr)
		}
		exporter = exp
	default:
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "unsupported otel protocol: %s", cfg.Protocol)
	}

	// Set up sampler
	sampler := trace.ParentBased(trace.TraceIDRatioBased(cfg.SamplingRate))

	// Set up tracer provider
	tp := trace.NewTracerProvider(
		trace.WithBatcher(exporter),
		trace.WithResource(res),
		trace.WithSampler(sampler),
	)
	otel.SetTracerProvider(tp)

	// Set up text map propagator for trace context propagation
	// This enables the backend to receive and process trace headers from NGINX
	otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
		propagation.TraceContext{},
		propagation.Baggage{},
	))
	return tp, nil
}
// Package services provides business logic services for the quiz application.
package services
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"runtime/debug"
"strconv"
"strings"
"sync"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"github.com/xeipuuv/gojsonschema"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// JSON Schema definitions for the grammar field.
//
// These schemas are sent in the 'grammar' field of OpenAI-compatible API
// requests to constrain the JSON structure of the model's reply, and are also
// used to validate questions after the fact (see ValidateQuestionSchema).
//
// The grammar field is conditionally included based on provider support (see supportsGrammarField).
// Providers that don't support grammar (like Google) will fall back to prompt-based structure guidance.
//
// All three schemas describe one JSON object with exactly four string options,
// an integer correct_answer, and an optional "topic"; they differ only in the
// extra required field ("passage" or "sentence").
const (
// SingleQuestionSchema describes a single question object (used by ai-fix).
// Required: question, options, correct_answer, explanation.
SingleQuestionSchema = `{
"type": "object",
"properties": {
"question": {"type": "string"},
"options": {"type": "array", "items": {"type": "string"}, "minItems": 4, "maxItems": 4},
"correct_answer": {"type": "integer"},
"explanation": {"type": "string"},
"topic": {"type": "string"}
},
"required": ["question", "options", "correct_answer", "explanation"]
}`
// SingleReadingComprehensionSchema describes a single reading comprehension
// question object; it additionally requires "passage".
SingleReadingComprehensionSchema = `{
"type": "object",
"properties": {
"passage": {"type": "string"},
"question": {"type": "string"},
"options": {"type": "array", "items": {"type": "string"}, "minItems": 4, "maxItems": 4},
"correct_answer": {"type": "integer"},
"explanation": {"type": "string"},
"topic": {"type": "string"}
},
"required": ["passage", "question", "options", "correct_answer", "explanation"]
}`
// SingleVocabularyQuestionSchema describes a single vocabulary question
// object; it additionally requires "sentence".
SingleVocabularyQuestionSchema = `{
"type": "object",
"properties": {
"sentence": {"type": "string"},
"question": {"type": "string"},
"options": {"type": "array", "items": {"type": "string"}, "minItems": 4, "maxItems": 4},
"correct_answer": {"type": "integer"},
"explanation": {"type": "string"},
"topic": {"type": "string"}
},
"required": ["sentence", "question", "options", "correct_answer", "explanation"]
}`
)
// Batch schemas: each one is a JSON Schema of type "array" whose "items"
// schema is the corresponding Single* schema, embedded verbatim.
var (
// BatchQuestionsSchema is a batch wrapper around SingleQuestionSchema.
BatchQuestionsSchema = fmt.Sprintf(`{"type":"array","items":%s}`, SingleQuestionSchema)
// BatchReadingComprehensionSchema is a batch wrapper around SingleReadingComprehensionSchema.
BatchReadingComprehensionSchema = fmt.Sprintf(`{"type":"array","items":%s}`, SingleReadingComprehensionSchema)
// BatchVocabularyQuestionSchema is a batch wrapper around SingleVocabularyQuestionSchema.
BatchVocabularyQuestionSchema = fmt.Sprintf(`{"type":"array","items":%s}`, SingleVocabularyQuestionSchema)
)
// UserAIConfig holds per-user AI configuration.
//
// NOTE(review): the AIServiceInterface methods below take
// *models.UserAIConfig rather than this type — confirm whether this
// declaration is still in use.
type UserAIConfig struct {
Provider string
Model string
APIKey string
Username string // For logging purposes
}
// AIServiceInterface defines the interface for AI-powered question generation,
// chat, and story generation, plus the supporting accessors and lifecycle
// methods that callers of the AI service rely on.
type AIServiceInterface interface {
// Question generation. The streaming variant takes a send-only progress
// channel of questions and a set of variety elements.
GenerateQuestion(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest) (*models.Question, error)
GenerateQuestions(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest) ([]*models.Question, error)
GenerateQuestionsStream(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest, progress chan<- *models.Question, variety *VarietyElements) error
// Chat. The streaming variant takes a send-only channel of string chunks.
GenerateChatResponse(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIChatRequest) (string, error)
GenerateChatResponseStream(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIChatRequest, chunks chan<- string) error
// Story generation.
GenerateStorySection(ctx context.Context, userConfig *models.UserAIConfig, req *models.StoryGenerationRequest) (string, error)
GenerateStoryQuestions(ctx context.Context, userConfig *models.UserAIConfig, req *models.StoryQuestionsRequest) ([]*models.StorySectionQuestionData, error)
// Connectivity check and introspection.
TestConnection(ctx context.Context, provider, model, apiKey string) error
GetConcurrencyStats() ConcurrencyStats
GetQuestionBatchSize(provider string) int
VarietyService() *VarietyService
// TemplateManager exposes template rendering and example loading for prompts
TemplateManager() *AITemplateManager
// SupportsGrammarField reports whether the provider supports the grammar field
SupportsGrammarField(provider string) bool
// CallWithPrompt sends a raw prompt (and optional grammar) to the provider and returns the response
CallWithPrompt(ctx context.Context, userConfig *models.UserAIConfig, prompt, grammar string) (string, error)
// Shutdown stops the service; see (*AIService).Shutdown for the concrete behavior.
Shutdown(ctx context.Context) error
}
// ConcurrencyStats provides metrics about AI request concurrency.
// It is returned by GetConcurrencyStats and serializes to JSON using the
// snake_case names given in the field tags.
type ConcurrencyStats struct {
ActiveRequests int `json:"active_requests"`
MaxConcurrent int `json:"max_concurrent"`
QueuedRequests int `json:"queued_requests"`
TotalRequests int64 `json:"total_requests"`
UserActiveCount map[string]int `json:"user_active_count"`
MaxPerUser int `json:"max_per_user"`
}
// AIService provides AI-powered question generation using OpenAI-compatible APIs.
// Instances are created with NewAIService. The struct contains mutexes, so it
// must be used through a pointer and never copied.
type AIService struct {
httpClient *http.Client
debug bool
cfg *config.Config
// Template management
templateManager *AITemplateManager
// Variety service for question diversity
varietyService *VarietyService
// Usage stats service for tracking token usage
usageStatsSvc UsageStatsServiceInterface
// Concurrency control
globalSemaphore chan struct{} // Limits total concurrent requests
maxConcurrent int // Maximum concurrent requests globally
maxPerUser int // Maximum concurrent requests per user
// Per-user concurrency tracking
userRequestCount map[string]int // Username -> active request count
concurrencyMu sync.RWMutex // Protects userRequestCount
// Metrics
totalRequests int64 // Total requests processed
activeRequests int // Current active requests
statsMu sync.RWMutex // Protects stats
// Observability
logger *observability.Logger
// Shutdown control: shutdownCtx starts as context.Background() and is
// replaced by Shutdown; isShutdown reports whether it is done.
shutdownCtx context.Context
shutdownMu sync.RWMutex // Protects shutdownCtx
}
// Schema validation counters, keyed by question type. These are exported,
// package-level mutable maps.
// NOTE(review): presumably every read and write must hold SchemaValidationMu —
// the accessing code is outside this chunk, confirm there.
var (
// SchemaValidationFailures counts validation failures per question type.
SchemaValidationFailures = make(map[models.QuestionType]int)
// SchemaValidationFailureDetails holds error detail strings per question type.
SchemaValidationFailureDetails = make(map[models.QuestionType][]string) // error details
// SchemaValidationMu is the mutex declared alongside the two maps above.
SchemaValidationMu sync.Mutex
)
// extractItemsSchema returns the "items" sub-schema of a batch (array) JSON
// schema as a JSON string.
//
// The batch schema is decoded into a generic map and the "items" value is
// re-encoded, so the result is semantically equal to the embedded schema but
// not necessarily byte-identical to it. An error is returned when batchSchema
// is not a JSON object, when it has no "items" key, or when re-encoding fails.
func extractItemsSchema(batchSchema string) (result0 string, err error) {
	var parsed map[string]interface{}
	if err = json.Unmarshal([]byte(batchSchema), &parsed); err != nil {
		return "", err
	}

	items, found := parsed["items"]
	if !found {
		return "", contextutils.ErrorWithContextf("no items found in batch schema")
	}

	encoded, marshalErr := json.Marshal(items)
	if marshalErr != nil {
		return "", marshalErr
	}
	return string(encoded), nil
}
// ValidateQuestionSchema checks a question against the JSON schema for its
// question type and reports whether it conforms.
//
// question may be any JSON-marshalable value. When it is a *models.Question,
// only its Content is validated. The schema is the item schema of the batch
// schema for qType: vocabulary, reading comprehension, or the generic question
// schema for fill-in-blank and question-answer.
//
// It returns (true, nil) for a conforming question and (false, err) otherwise:
// for a nil question, an unknown type, a schema/marshal/validator failure, or a
// schema violation (the error lists every violation, joined by "; "). The
// outcome is recorded on the span as "validation.result", with
// "validation.error" added for internal failures.
func (s *AIService) ValidateQuestionSchema(ctx context.Context, qType models.QuestionType, question interface{}) (result0 bool, err error) {
	_, span := observability.TraceAIFunction(ctx, "validate_question_schema",
		observability.AttributeQuestionType(qType),
	)
	defer observability.FinishSpan(span, &err)

	// mark tags the span with the outcome and, if given, the underlying cause.
	mark := func(outcome string, cause error) {
		attrs := []attribute.KeyValue{attribute.String("validation.result", outcome)}
		if cause != nil {
			attrs = append(attrs, attribute.String("validation.error", cause.Error()))
		}
		span.SetAttributes(attrs...)
	}

	if question == nil {
		mark("nil_question", nil)
		return false, contextutils.ErrorWithContextf("question cannot be nil")
	}

	// Pick the batch schema for this question type.
	var batchSchema string
	switch qType {
	case models.FillInBlank, models.QuestionAnswer:
		batchSchema = BatchQuestionsSchema
	case models.ReadingComprehension:
		batchSchema = BatchReadingComprehensionSchema
	case models.Vocabulary:
		batchSchema = BatchVocabularyQuestionSchema
	default:
		mark("unknown_type", nil)
		return false, contextutils.ErrorWithContextf("unknown question type: %v", qType)
	}

	// A single question is validated against the batch schema's item schema.
	itemSchema, err := extractItemsSchema(batchSchema)
	if err != nil {
		mark("schema_extract_error", err)
		return false, contextutils.WrapErrorf(err, "failed to extract schema for question type %v", qType)
	}

	// For a question model, only its content payload is validated. A typed nil
	// pointer passes the interface nil check above, so it is rejected here.
	payload := question
	if model, isModel := question.(*models.Question); isModel {
		if model == nil {
			mark("nil_question_model", nil)
			return false, contextutils.ErrorWithContextf("question model is nil")
		}
		payload = model.Content
	}

	payloadJSON, err := json.Marshal(payload)
	if err != nil {
		mark("marshal_error", err)
		return false, contextutils.WrapErrorf(err, "failed to marshal question for validation")
	}

	outcome, err := gojsonschema.Validate(
		gojsonschema.NewStringLoader(itemSchema),
		gojsonschema.NewBytesLoader(payloadJSON),
	)
	if err != nil {
		mark("validate_error", err)
		return false, contextutils.WrapErrorf(err, "schema validation failed for question type %v", qType)
	}

	if outcome.Valid() {
		mark("valid", nil)
		return true, nil
	}

	// Collect every violation into one error message.
	violations := outcome.Errors()
	messages := make([]string, 0, len(violations))
	for i := range violations {
		messages = append(messages, violations[i].String())
	}
	mark("invalid", nil)
	return false, contextutils.ErrorWithContextf("question failed schema validation: %s", strings.Join(messages, "; "))
}
// NewAIService constructs an AIService from the application config.
//
// It panics when usageStatsSvc is nil or when the template manager cannot be
// created (the latter is logged first); both are treated as fatal
// initialization errors. The HTTP client is instrumented with otelhttp and
// uses a timeout five seconds shorter than config.AIRequestTimeout so that
// context cancellation has room to take effect. Concurrency limits come from
// cfg.Server.MaxAIConcurrent (global semaphore capacity) and
// cfg.Server.MaxAIPerUser.
func NewAIService(cfg *config.Config, logger *observability.Logger, usageStatsSvc UsageStatsServiceInterface) *AIService {
	if usageStatsSvc == nil {
		panic("usageStatsSvc is required for AI service")
	}

	templates, err := NewAITemplateManager()
	if err != nil {
		logger.Error(context.Background(), "Failed to create template manager", err, map[string]interface{}{})
		panic(err) // Use panic for fatal errors in initialization
	}

	variety := NewVarietyServiceWithLogger(cfg, logger)

	// Outgoing AI calls are traced as client spans.
	transport := otelhttp.NewTransport(http.DefaultTransport,
		otelhttp.WithSpanOptions(trace.WithSpanKind(trace.SpanKindClient)),
	)
	client := &http.Client{
		Timeout:   config.AIRequestTimeout - 5*time.Second, // Slightly less than AIRequestTimeout
		Transport: transport,
	}

	globalLimit := cfg.Server.MaxAIConcurrent
	perUserLimit := cfg.Server.MaxAIPerUser

	return &AIService{
		httpClient:       client,
		debug:            cfg.Server.Debug,
		cfg:              cfg,
		templateManager:  templates,
		varietyService:   variety,
		usageStatsSvc:    usageStatsSvc,
		globalSemaphore:  make(chan struct{}, globalLimit),
		maxConcurrent:    globalLimit,
		maxPerUser:       perUserLimit,
		userRequestCount: make(map[string]int),
		shutdownCtx:      context.Background(),
		logger:           logger,
	}
}
// Shutdown gracefully shuts down the AI service and cleans up resources.
//
// Sequence:
//  1. The service is marked as shutting down immediately: an already-cancelled
//     context is stored in shutdownCtx, so isShutdown reports true from this
//     point on. shutdownMu is held only for that assignment, not for the whole
//     wait, so isShutdown callers are never blocked behind Shutdown.
//  2. It polls activeRequests every config.AIShutdownPollInterval until the
//     count reaches zero or the wait budget is used up. The budget is the time
//     until ctx's deadline if it has one, otherwise config.AIShutdownTimeout.
//  3. Idle HTTP connections are closed and per-user request counts are reset.
//
// It returns ctx.Err() if ctx is cancelled while requests are still active
// (no cleanup is performed in that case), and nil otherwise — including when
// the wait budget expires with requests still running, which is logged as a
// warning.
func (s *AIService) Shutdown(ctx context.Context) error {
	// Publish the shutting-down state before waiting.
	shutdownCtx, cancel := context.WithCancel(ctx)
	cancel()
	s.shutdownMu.Lock()
	s.shutdownCtx = shutdownCtx
	s.shutdownMu.Unlock()

	// Work out how long we are allowed to wait for active requests.
	timeout := config.AIShutdownTimeout
	if deadline, ok := ctx.Deadline(); ok {
		timeout = time.Until(deadline)
	}
	waitUntil := time.Now().Add(timeout)

	ticker := time.NewTicker(config.AIShutdownPollInterval)
	defer ticker.Stop()

	remaining := 0
	for {
		s.statsMu.RLock()
		remaining = s.activeRequests
		s.statsMu.RUnlock()

		// Stop on success or once the wall-clock budget is exhausted.
		if remaining == 0 || !time.Now().Before(waitUntil) {
			break
		}

		select {
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	// Close the HTTP client
	if s.httpClient != nil {
		s.httpClient.CloseIdleConnections()
	}

	// Clean up user request counts
	s.concurrencyMu.Lock()
	s.userRequestCount = make(map[string]int)
	s.concurrencyMu.Unlock()

	if remaining > 0 {
		s.logger.Warn(ctx, "AI Service shutdown timed out with requests still active", map[string]interface{}{
			"active_requests": remaining,
		})
		return nil
	}
	s.logger.Info(ctx, "AI Service shutdown completed")
	return nil
}
// isShutdown reports whether the service's shutdown context is done. The
// context is read under shutdownMu's read lock.
func (s *AIService) isShutdown() bool {
	s.shutdownMu.RLock()
	defer s.shutdownMu.RUnlock()

	// A context's Err is non-nil exactly when its Done channel is closed.
	return s.shutdownCtx.Err() != nil
}
// OpenAIRequest represents a request to the OpenAI-compatible API.
// Grammar and Stream are omitted from the JSON body when empty/false;
// Temperature and MaxTokens are always sent, even when zero.
type OpenAIRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature"`
MaxTokens int `json:"max_tokens"`
Grammar string `json:"grammar,omitempty"`
Stream bool `json:"stream,omitempty"`
}
// Message represents a chat message in the API request. The same type is used
// for the message carried by a response Choice.
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
// OpenAIResponse represents a (non-streaming) response from the
// OpenAI-compatible API. Error and Usage are pointers and stay nil when the
// corresponding JSON key is absent.
type OpenAIResponse struct {
Choices []Choice `json:"choices"`
Error *APIError `json:"error,omitempty"`
Usage *Usage `json:"usage,omitempty"`
}
// Choice represents a choice in the API response.
type Choice struct {
// Message is the chat message decoded from "message".
Message Message `json:"message"`
}
// APIError represents an error response from the API.
type APIError struct {
// Message is the error text decoded from "message".
Message string `json:"message"`
// Type is the error category decoded from "type".
Type string `json:"type"`
}
// Usage represents token usage information from OpenAI API.
type Usage struct {
// PromptTokens is decoded from "prompt_tokens".
PromptTokens int `json:"prompt_tokens"`
// CompletionTokens is decoded from "completion_tokens".
CompletionTokens int `json:"completion_tokens"`
// TotalTokens is decoded from "total_tokens".
TotalTokens int `json:"total_tokens"`
}
// OpenAIStreamResponse represents a streaming response chunk from the OpenAI-compatible API.
type OpenAIStreamResponse struct {
// Choices is decoded from "choices".
Choices []StreamChoice `json:"choices"`
// Error is decoded from "error"; nil when the field is absent.
Error *APIError `json:"error,omitempty"`
// Usage is decoded from "usage"; nil when the field is absent.
Usage *Usage `json:"usage,omitempty"`
}
// StreamChoice represents a choice in the streaming API response.
type StreamChoice struct {
// Delta is the incremental content decoded from "delta".
Delta StreamDelta `json:"delta"`
// FinishReason is decoded from "finish_reason"; it is a pointer, so a JSON
// null or a missing field leaves it nil.
FinishReason *string `json:"finish_reason"`
}
// StreamDelta represents the delta content in a streaming response.
type StreamDelta struct {
// Content is the text fragment decoded from "content".
Content string `json:"content"`
}
// getGrammarSchema maps a question type to the batch JSON schema used for it.
// Reading-comprehension and vocabulary questions have dedicated batch schemas;
// every other type, known or unknown, gets the generic batch questions schema.
func getGrammarSchema(questionType models.QuestionType) string {
	switch questionType {
	case models.ReadingComprehension:
		return BatchReadingComprehensionSchema
	case models.Vocabulary:
		return BatchVocabularyQuestionSchema
	default:
		// FillInBlank, QuestionAnswer and unrecognised types share one schema.
		return BatchQuestionsSchema
	}
}
// GetFixSchema returns the single-item JSON schema used by ai-fix for the
// given question type. Unsupported types yield an empty string and an error
// wrapping contextutils.ErrAIConfigInvalid.
func GetFixSchema(questionType models.QuestionType) (string, error) {
	var schema string
	switch questionType {
	case models.ReadingComprehension:
		schema = SingleReadingComprehensionSchema
	case models.Vocabulary:
		schema = SingleVocabularyQuestionSchema
	case models.FillInBlank, models.QuestionAnswer:
		schema = SingleQuestionSchema
	default:
		return "", contextutils.WrapErrorf(contextutils.ErrAIConfigInvalid, "no schema for question type: %v", questionType)
	}
	return schema, nil
}
// addJSONStructureGuidance returns prompt with rendered JSON structure
// guidance appended, for providers that don't support the grammar field.
// The guidance template is rendered with the batch schema for questionType.
// A template rendering failure is logged and then panics.
func (s *AIService) addJSONStructureGuidance(prompt string, questionType models.QuestionType) string {
	guidance, renderErr := s.templateManager.RenderTemplate(JSONStructureGuidanceTemplate, AITemplateData{
		SchemaForPrompt: getGrammarSchema(questionType),
	})
	if renderErr != nil {
		s.logger.Error(context.Background(), "Failed to render JSON structure guidance template", renderErr, map[string]interface{}{})
		panic(renderErr)
	}
	return prompt + guidance
}
// GenerateQuestion generates a single question using AI.
//
// The batch prompt is built for req; providers that support the grammar field
// also receive the batch schema as grammar, all others get the JSON structure
// guidance embedded in the prompt. The raw completion is parsed with
// parseQuestionResponse. Any failure returns a nil question and the error.
func (s *AIService) GenerateQuestion(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest) (result0 *models.Question, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_question",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(string(req.QuestionType)),
	)
	defer observability.FinishSpan(span, &err)

	// grammar stays empty for providers without grammar support.
	var prompt, grammar string
	if s.supportsGrammarField(userConfig.Provider) {
		prompt = s.buildBatchQuestionPrompt(ctx, req, nil)
		grammar = getGrammarSchema(req.QuestionType)
	} else {
		prompt = s.buildBatchQuestionPromptWithJSONStructure(ctx, req, nil)
	}

	raw, callErr := s.callOpenAI(ctx, userConfig, prompt, grammar)
	if callErr != nil {
		return nil, callErr
	}

	parsed, parseErr := s.parseQuestionResponse(ctx, raw, req.Language, req.Level, req.QuestionType, userConfig.Provider)
	if parseErr != nil {
		return nil, parseErr
	}
	return parsed, nil
}
// GenerateQuestions generates multiple questions in a single batch request.
//
// Providers with grammar support get the plain batch prompt plus the batch
// schema as grammar; other providers get the prompt with JSON structure
// guidance embedded and no grammar. The completion is parsed with
// parseQuestionsResponse. Any failure returns a nil slice and the error.
func (s *AIService) GenerateQuestions(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest) (result0 []*models.Question, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_questions",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(string(req.QuestionType)),
		observability.AttributeLimit(req.Count),
	)
	defer observability.FinishSpan(span, &err)

	var (
		prompt  string
		grammar string // left empty when the provider has no grammar support
	)
	switch s.supportsGrammarField(userConfig.Provider) {
	case true:
		prompt = s.buildBatchQuestionPrompt(ctx, req, nil)
		grammar = getGrammarSchema(req.QuestionType)
	default:
		prompt = s.buildBatchQuestionPromptWithJSONStructure(ctx, req, nil)
	}

	raw, callErr := s.callOpenAI(ctx, userConfig, prompt, grammar)
	if callErr != nil {
		return nil, callErr
	}

	parsed, parseErr := s.parseQuestionsResponse(ctx, raw, req.Language, req.Level, req.QuestionType, userConfig.Provider)
	if parseErr != nil {
		return nil, parseErr
	}
	return parsed, nil
}
// GenerateQuestionsStream generates questions and streams them on progress,
// using the provided variety elements. Generation runs under
// withConcurrencyControl for the user and proceeds in batches sized by
// getQuestionBatchSize for the provider. progress is closed when this method
// returns.
func (s *AIService) GenerateQuestionsStream(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest, progress chan<- *models.Question, variety *VarietyElements) (err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_questions_stream",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(string(req.QuestionType)),
		observability.AttributeLimit(req.Count),
	)
	defer observability.FinishSpan(span, &err)
	defer close(progress)

	streamBatches := func() error {
		perBatch := s.getQuestionBatchSize(userConfig.Provider)
		return s.generateQuestionsInBatchesWithVariety(ctx, userConfig, req, progress, perBatch, variety)
	}
	return s.withConcurrencyControl(ctx, userConfig.Username, streamBatches)
}
// generateQuestionsInBatchesWithVariety generates req.Count questions in
// batches of at most batchSize, using the provided variety elements, and sends
// each question on progress as soon as its batch has been parsed.
//
// The text of every generated question is appended to a local copy of
// req.RecentQuestionHistory so later batches see what earlier batches
// produced; req itself is not modified.
//
// It returns ctx.Err() when ctx is cancelled between batches or while a send
// on progress is blocked, a wrapped error when a batch fails, and an error
// wrapping contextutils.ErrAIResponseInvalid when a batch comes back empty
// (an empty batch would otherwise never make progress and loop forever).
func (s *AIService) generateQuestionsInBatchesWithVariety(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest, progress chan<- *models.Question, batchSize int, variety *VarietyElements) (err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_questions_in_batches_with_variety",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(req.QuestionType),
		observability.AttributeLanguage(req.Language),
		observability.AttributeLevel(req.Level),
		attribute.Int("batch_size", batchSize),
		attribute.Int("total_count", req.Count),
		attribute.Bool("variety.enabled", variety != nil),
	)
	defer observability.FinishSpan(span, &err)

	// Local copy of history to be updated as we generate questions
	localHistory := make([]string, len(req.RecentQuestionHistory))
	copy(localHistory, req.RecentQuestionHistory)

	remaining := req.Count
	for remaining > 0 {
		// Stop before issuing another provider call once ctx is cancelled.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}

		// Calculate how many questions to generate in this batch
		currentBatchSize := min(remaining, batchSize)
		batchReq := &models.AIQuestionGenRequest{
			Language:              req.Language,
			Level:                 req.Level,
			QuestionType:          req.QuestionType,
			Count:                 currentBatchSize,
			RecentQuestionHistory: localHistory,
		}

		questions, genErr := s.generateQuestionsWithVariety(ctx, userConfig, batchReq, variety)
		if genErr != nil {
			return contextutils.WrapErrorf(genErr, "failed to generate batch of %d questions for user %s", currentBatchSize, userConfig.Username)
		}
		// Without this guard remaining would never shrink and the loop would
		// keep calling the provider indefinitely.
		if len(questions) == 0 {
			return contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "AI returned no questions for batch of %d for user %s", currentBatchSize, userConfig.Username)
		}

		// Stream the generated questions
		for _, question := range questions {
			// Add generated question content to history for next iterations
			if qContent, ok := question.Content["question"]; ok {
				if qStr, ok := qContent.(string); ok {
					localHistory = append(localHistory, qStr)
				}
			}
			// Do not block forever on a consumer that has gone away.
			select {
			case progress <- question:
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		remaining -= len(questions)
	}
	return nil
}
// generateQuestionsWithVariety generates one batch of questions using the
// provided variety elements. The prompt is built with variety applied; the
// batch schema is sent as grammar only when the provider supports the grammar
// field. Any failure returns a nil slice and the error, which is also recorded
// on the span with a stack trace.
func (s *AIService) generateQuestionsWithVariety(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIQuestionGenRequest, variety *VarietyElements) (result0 []*models.Question, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_questions_with_variety",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(req.QuestionType),
		observability.AttributeLanguage(req.Language),
		observability.AttributeLevel(req.Level),
		attribute.Int("count", req.Count),
		attribute.Bool("variety.enabled", variety != nil),
	)
	defer func() {
		// The span is always ended, after any error has been recorded on it.
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	var prompt, grammar string
	if s.supportsGrammarField(userConfig.Provider) {
		prompt = s.buildBatchQuestionPrompt(ctx, req, variety)
		grammar = getGrammarSchema(req.QuestionType)
	} else {
		prompt = s.buildBatchQuestionPromptWithJSONStructure(ctx, req, variety)
	}

	raw, callErr := s.callOpenAI(ctx, userConfig, prompt, grammar)
	if callErr != nil {
		return nil, callErr
	}

	parsed, parseErr := s.parseQuestionsResponse(ctx, raw, req.Language, req.Level, req.QuestionType, userConfig.Provider)
	if parseErr != nil {
		return nil, parseErr
	}
	return parsed, nil
}
// GenerateChatResponse generates a chat response using AI. The chat prompt is
// built from req and sent, without a grammar constraint, under
// withConcurrencyControl for the user. An error from the concurrency wrapper
// or the AI call yields an empty string and that error.
func (s *AIService) GenerateChatResponse(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIChatRequest) (result0 string, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_chat_response",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
	)
	defer observability.FinishSpan(span, &err)

	var (
		reply    string
		replyErr error
	)
	askModel := func() error {
		chatPrompt := s.buildChatPrompt(req)
		// Open-ended chat: no grammar is sent.
		reply, replyErr = s.callOpenAI(ctx, userConfig, chatPrompt, "")
		return replyErr
	}
	if err = s.withConcurrencyControl(ctx, userConfig.Username, askModel); err != nil {
		return "", err
	}
	return reply, replyErr
}
// GenerateChatResponseStream generates a streaming chat response using AI,
// delivering output on chunks. The call runs under withConcurrencyControl for
// the user and uses no grammar constraint. chunks is not closed here; closing
// it is left to the caller to avoid race conditions.
func (s *AIService) GenerateChatResponseStream(ctx context.Context, userConfig *models.UserAIConfig, req *models.AIChatRequest, chunks chan<- string) (err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_chat_response_stream",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
	)
	defer observability.FinishSpan(span, &err)

	streamReply := func() error {
		chatPrompt := s.buildChatPrompt(req)
		return s.callOpenAIStream(ctx, userConfig, chatPrompt, "", chunks)
	}
	return s.withConcurrencyControl(ctx, userConfig.Username, streamReply)
}
// GenerateStorySection generates a story section using AI. The story section
// prompt is built from req and sent through callOpenAIWithRetry, without a
// grammar, under withConcurrencyControl for the user. On failure it returns an
// empty string and the error.
func (s *AIService) GenerateStorySection(ctx context.Context, userConfig *models.UserAIConfig, req *models.StoryGenerationRequest) (result string, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_story_section",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		attribute.String("story.title", req.Title),
		attribute.String("story.language", req.Language),
		attribute.String("story.level", req.Level),
		attribute.Bool("story.is_first_section", req.IsFirstSection),
	)
	defer observability.FinishSpan(span, &err)

	var (
		section    string
		sectionErr error
	)
	writeSection := func() error {
		sectionPrompt := s.buildStorySectionPrompt(req)
		section, sectionErr = s.callOpenAIWithRetry(ctx, userConfig, sectionPrompt, "")
		return sectionErr
	}
	if err = s.withConcurrencyControl(ctx, userConfig.Username, writeSection); err != nil {
		return "", err
	}
	return section, sectionErr
}
// GenerateStoryQuestions generates comprehension questions for a story
// section. The prompt is built from req, sent without a grammar under
// withConcurrencyControl for the user, and the completion is parsed with
// parseStoryQuestionsResponse. On failure it returns a nil slice and the
// error; parse failures are wrapped.
func (s *AIService) GenerateStoryQuestions(ctx context.Context, userConfig *models.UserAIConfig, req *models.StoryQuestionsRequest) (result []*models.StorySectionQuestionData, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_story_questions",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		attribute.String("story.language", req.Language),
		attribute.String("story.level", req.Level),
		attribute.Int("questions.count", req.QuestionCount),
	)
	defer observability.FinishSpan(span, &err)

	var (
		parsed   []*models.StorySectionQuestionData
		parseErr error
	)
	askAndParse := func() error {
		raw, callErr := s.callOpenAI(ctx, userConfig, s.buildStoryQuestionsPrompt(req), "")
		if callErr != nil {
			return callErr
		}
		// Turn the JSON completion into question data.
		if parsed, parseErr = s.parseStoryQuestionsResponse(raw); parseErr != nil {
			return contextutils.WrapErrorf(parseErr, "failed to parse story questions response")
		}
		return nil
	}
	if err = s.withConcurrencyControl(ctx, userConfig.Username, askAndParse); err != nil {
		return nil, err
	}
	return parsed, parseErr
}
// stringPtrToString dereferences ptr, yielding the empty string for a nil
// pointer.
func stringPtrToString(ptr *string) string {
	var value string
	if ptr != nil {
		value = *ptr
	}
	return value
}
// buildStorySectionPrompt renders "story_section_prompt.tmpl" with the fields
// of req and returns the resulting prompt. Optional (pointer) request fields
// are passed to the template as empty strings when nil. There is no fallback:
// a rendering failure panics.
func (s *AIService) buildStorySectionPrompt(req *models.StoryGenerationRequest) string {
	rendered, renderErr := s.templateManager.RenderTemplate("story_section_prompt.tmpl", AITemplateData{
		Language:           req.Language,
		Level:              req.Level,
		Title:              req.Title,
		Subject:            stringPtrToString(req.Subject),
		AuthorStyle:        stringPtrToString(req.AuthorStyle),
		TimePeriod:         stringPtrToString(req.TimePeriod),
		Genre:              stringPtrToString(req.Genre),
		Tone:               stringPtrToString(req.Tone),
		CharacterNames:     stringPtrToString(req.CharacterNames),
		CustomInstructions: stringPtrToString(req.CustomInstructions),
		TargetWords:        req.TargetWords,
		TargetSentences:    req.TargetSentences,
		IsFirstSection:     req.IsFirstSection,
		PreviousSections:   req.PreviousSections,
	})
	if renderErr != nil {
		panic(contextutils.WrapErrorf(renderErr, "failed to render story section template"))
	}
	return rendered
}
// buildStoryQuestionsPrompt renders "story_questions_prompt.tmpl" with the
// language, level, question count and section text from req and returns the
// resulting prompt. There is no fallback: a rendering failure panics.
func (s *AIService) buildStoryQuestionsPrompt(req *models.StoryQuestionsRequest) string {
	rendered, renderErr := s.templateManager.RenderTemplate("story_questions_prompt.tmpl", AITemplateData{
		Language:    req.Language,
		Level:       req.Level,
		Count:       req.QuestionCount,
		SectionText: req.SectionText,
	})
	if renderErr != nil {
		panic(contextutils.WrapErrorf(renderErr, "failed to render story questions template"))
	}
	return rendered
}
// parseStoryQuestionsResponse parses the AI response into question data.
//
// Surrounding whitespace and an optional markdown code fence (either
// "```json" or a bare "```") are stripped before the payload is unmarshalled
// as a JSON array of questions. Each question must be non-null, have a
// non-empty question text, exactly 4 options and a correct_answer_index in
// 0-3. The first violation, or a JSON decoding failure, returns a nil slice
// and an error.
func (s *AIService) parseStoryQuestionsResponse(response string) ([]*models.StorySectionQuestionData, error) {
	const (
		fence           = "```"
		expectedOptions = 4
	)

	// Clean the response (remove markdown code blocks if present)
	response = strings.TrimSpace(response)
	if strings.HasPrefix(response, fence) {
		response = strings.TrimPrefix(response, fence)
		// The language tag after the opening fence is optional.
		response = strings.TrimPrefix(response, "json")
		response = strings.TrimSuffix(response, fence)
		response = strings.TrimSpace(response)
	}

	var questions []*models.StorySectionQuestionData
	if err := json.Unmarshal([]byte(response), &questions); err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to unmarshal questions JSON")
	}

	// Validate the questions
	for i, q := range questions {
		// A JSON null element decodes to a nil pointer; reject it here rather
		// than dereferencing it below.
		if q == nil {
			return nil, contextutils.ErrorWithContextf("question %d: entry is null", i)
		}
		if q.QuestionText == "" {
			return nil, contextutils.ErrorWithContextf("question %d: missing question text", i)
		}
		if len(q.Options) != expectedOptions {
			return nil, contextutils.ErrorWithContextf("question %d: must have exactly 4 options, got %d", i, len(q.Options))
		}
		if q.CorrectAnswerIndex < 0 || q.CorrectAnswerIndex >= expectedOptions {
			return nil, contextutils.ErrorWithContextf("question %d: correct_answer_index must be 0-3, got %d", i, q.CorrectAnswerIndex)
		}
	}
	return questions, nil
}
// TestConnection tests the connection to the AI service by sending a tiny
// prompt to the given provider and model.
//
// provider and model are required, and apiKey is required for every provider
// except "ollama"; violations return an error wrapping
// contextutils.ErrAIConfigInvalid. The call is bounded by
// config.AIRequestTimeout. A failed call is returned wrapped with the provider
// and model; an empty response or one shorter than 3 bytes returns an error
// wrapping contextutils.ErrAIResponseInvalid. The outcome is recorded on the
// span as the "test.result" attribute.
func (s *AIService) TestConnection(ctx context.Context, provider, model, apiKey string) (err error) {
	// Keep the traced context so the nested AI call and the log entries are
	// attached to this span, as in the other methods of the service.
	ctx, span := observability.TraceAIFunction(ctx, "test_connection",
		attribute.String("ai.provider", provider),
		attribute.String("ai.model", model),
	)
	defer observability.FinishSpan(span, &err)

	// Validate input parameters
	if provider == "" {
		span.SetAttributes(attribute.String("test.result", "empty_provider"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "provider is required for testing connection")
	}
	if model == "" {
		span.SetAttributes(attribute.String("test.result", "empty_model"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "model is required for testing connection")
	}

	// Only the masked form of the API key is logged.
	s.logger.Debug(ctx, "TestConnection called", map[string]interface{}{
		"provider": provider,
		"model":    model,
		"apiKey":   contextutils.MaskAPIKey(apiKey),
	})

	// Require API key for all providers that are not Ollama
	if provider != "ollama" && apiKey == "" {
		span.SetAttributes(attribute.String("test.result", "missing_api_key"), attribute.String("provider", provider))
		return contextutils.WrapErrorf(contextutils.ErrAIConfigInvalid, "API key is required for testing connection with provider '%s'", provider)
	}

	// Create a simple test configuration
	userConfig := &models.UserAIConfig{
		Provider: provider,
		Model:    model,
		APIKey:   apiKey,
		Username: "test-user",
	}
	s.logger.Debug(ctx, "Created userConfig", map[string]interface{}{
		"provider": userConfig.Provider,
		"model":    userConfig.Model,
	})

	// Create a simple test request
	testPrompt := "Respond with exactly the word 'SUCCESS' and nothing else."

	// Create a timeout context for the test
	testCtx, cancel := context.WithTimeout(ctx, config.AIRequestTimeout)
	defer cancel()

	// Test the actual AI service call
	response, err := s.callOpenAI(testCtx, userConfig, testPrompt, "")
	if err != nil {
		span.SetAttributes(attribute.String("test.result", "call_failed"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(err, "connection test failed for provider '%s' with model '%s'", provider, model)
	}

	// Check if we got a reasonable response
	if response == "" {
		span.SetAttributes(attribute.String("test.result", "empty_response"))
		return contextutils.WrapError(contextutils.ErrAIResponseInvalid, "connection test failed: received empty response from AI service")
	}
	// Only the length is checked; the reply text itself is not compared.
	if len(response) < 3 {
		span.SetAttributes(attribute.String("test.result", "response_too_short"), attribute.Int("response_length", len(response)))
		return contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "connection test failed: response too short (%d characters)", len(response))
	}

	s.logger.Info(ctx, "TestConnection successful", map[string]interface{}{
		"provider":        provider,
		"response_length": len(response),
	})
	span.SetAttributes(attribute.String("test.result", "success"), attribute.Int("response_length", len(response)))
	return nil
}
// buildBatchQuestionPromptWithJSONStructure builds the batch question prompt
// for req (with the optional variety elements) and appends the JSON structure
// guidance for req.QuestionType to it.
func (s *AIService) buildBatchQuestionPromptWithJSONStructure(ctx context.Context, req *models.AIQuestionGenRequest, variety *VarietyElements) string {
	return s.addJSONStructureGuidance(s.buildBatchQuestionPrompt(ctx, req, variety), req.QuestionType)
}
// buildBatchQuestionPrompt renders the batch question prompt template for req.
// When variety is non-nil its elements are copied into the template data. The
// example for the question type is included when it can be loaded; a load
// failure is ignored. A template rendering failure is logged and then panics.
func (s *AIService) buildBatchQuestionPrompt(ctx context.Context, req *models.AIQuestionGenRequest, variety *VarietyElements) string {
	_, span := observability.TraceAIFunction(ctx, "build_batch_question_prompt",
		observability.AttributeQuestionType(req.QuestionType),
		observability.AttributeLanguage(req.Language),
		observability.AttributeLevel(req.Level),
		attribute.Int("count", req.Count),
		attribute.Bool("variety.enabled", variety != nil),
	)
	defer span.End()

	var tmplData AITemplateData
	tmplData.SchemaForPrompt = getGrammarSchema(req.QuestionType)
	tmplData.Language = req.Language
	tmplData.Level = req.Level
	tmplData.QuestionType = string(req.QuestionType)
	tmplData.Count = req.Count
	tmplData.RecentQuestionHistory = req.RecentQuestionHistory

	if variety != nil {
		tmplData.TopicCategory = variety.TopicCategory
		tmplData.GrammarFocus = variety.GrammarFocus
		tmplData.VocabularyDomain = variety.VocabularyDomain
		tmplData.Scenario = variety.Scenario
		tmplData.StyleModifier = variety.StyleModifier
		tmplData.DifficultyModifier = variety.DifficultyModifier
		tmplData.TimeContext = variety.TimeContext
	}

	// Priority data is handled by the worker, not passed to AI service.
	// The example is optional: it is only set when loading succeeds.
	exampleContent, exampleErr := s.templateManager.LoadExample(string(req.QuestionType))
	if exampleErr == nil {
		tmplData.ExampleContent = exampleContent
	}

	rendered, renderErr := s.templateManager.RenderTemplate(BatchQuestionPromptTemplate, tmplData)
	if renderErr != nil {
		s.logger.Error(ctx, "Failed to render batch question prompt template", renderErr, map[string]interface{}{})
		panic(renderErr) // template rendering failures are treated as fatal
	}
	return rendered
}
// buildChatPrompt renders the chat prompt template from req. The request's
// conversation history is converted to ChatMessage values first. A template
// rendering failure is logged and then panics.
func (s *AIService) buildChatPrompt(req *models.AIChatRequest) string {
	// history stays nil when the request carries no prior messages.
	var history []ChatMessage
	for i := range req.ConversationHistory {
		entry := req.ConversationHistory[i]
		history = append(history, ChatMessage{Role: string(entry.Role), Content: entry.Content})
	}

	rendered, renderErr := s.templateManager.RenderTemplate(ChatPromptTemplate, AITemplateData{
		Language:            req.Language,
		Level:               req.Level,
		QuestionType:        string(req.QuestionType),
		Passage:             req.Passage,
		Question:            req.Question,
		Options:             req.Options,
		IsCorrect:           req.IsCorrect,
		ConversationHistory: history,
		UserMessage:         req.UserMessage,
	})
	if renderErr != nil {
		s.logger.Error(context.Background(), "Failed to render chat prompt template", renderErr, map[string]interface{}{})
		panic(renderErr) // template rendering failures are treated as fatal
	}
	return rendered
}
// getMaxTokensForModel returns the configured max_tokens for the given
// provider and model. Only the first provider entry whose code matches, and
// within it the first model entry whose code matches, is consulted. When no
// entry matches or the configured value is not positive, 4000 is returned.
func (s *AIService) getMaxTokensForModel(provider, model string) int {
	const fallbackMaxTokens = 4000

	for _, providerEntry := range s.cfg.Providers {
		if providerEntry.Code != provider {
			continue
		}
		for _, modelEntry := range providerEntry.Models {
			if modelEntry.Code != model {
				continue
			}
			if modelEntry.MaxTokens > 0 {
				return modelEntry.MaxTokens
			}
			// Matching model without a usable limit: stop searching.
			return fallbackMaxTokens
		}
		// Matching provider without a matching model: stop searching.
		return fallbackMaxTokens
	}
	return fallbackMaxTokens
}
// callOpenAI makes a request to the OpenAI-compatible API.
//
// It validates userConfig and prompt, resolves the provider's base URL from
// s.cfg.Providers, POSTs a single-message chat request to
// <base URL>/chat/completions and returns the content of the first choice.
// grammar is sent only when it is non-empty and the provider supports the
// grammar field. When the response carries usage information it is passed to
// trackAIUsage.
//
// Errors wrap contextutils.ErrAIConfigInvalid for missing configuration or an
// empty prompt, contextutils.ErrRateLimit for HTTP 429,
// contextutils.ErrAIRequestFailed for other non-200 statuses and API-reported
// errors, and contextutils.ErrAIResponseInvalid for unparsable, choice-less or
// empty responses. The outcome is recorded on the span as "call.result".
func (s *AIService) callOpenAI(ctx context.Context, userConfig *models.UserAIConfig, prompt, grammar string) (result0 string, err error) {
// userConfig is dereferenced when building the span attributes below, so it
// must be checked before the span is started.
if userConfig == nil {
return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "userConfig is required")
}
ctx, span := observability.TraceAIFunction(ctx, "call_openai",
attribute.String("ai.provider", userConfig.Provider),
attribute.String("ai.model", userConfig.Model),
attribute.String("ai.username", userConfig.Username),
attribute.Int("prompt.length", len(prompt)),
attribute.Bool("grammar.enabled", grammar != ""),
)
// Record the returned error (with stack trace) on the span, then end it.
defer func() {
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
}
span.End()
}()
// Validate input parameters
if userConfig.Provider == "" {
span.SetAttributes(attribute.String("call.result", "empty_provider"))
return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "provider is required")
}
if userConfig.Model == "" {
span.SetAttributes(attribute.String("call.result", "empty_model"))
return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "model is required")
}
if prompt == "" {
span.SetAttributes(attribute.String("call.result", "empty_prompt"))
return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "prompt cannot be empty")
}
apiURL := ""
model := userConfig.Model
apiKey := userConfig.APIKey
// Look up the default URL from provider config; the first entry with a
// matching code and a non-empty URL wins.
if s.cfg.Providers != nil {
for _, providerConfig := range s.cfg.Providers {
if providerConfig.Code == userConfig.Provider && providerConfig.URL != "" {
apiURL = providerConfig.URL
break
}
}
}
if apiURL == "" {
span.SetAttributes(attribute.String("call.result", "no_url_configured"), attribute.String("provider", userConfig.Provider))
return "", contextutils.WrapErrorf(contextutils.ErrAIConfigInvalid, "no base URL configured for provider '%s'", userConfig.Provider)
}
// userPrefix tags log entries with the requesting user, when known.
userPrefix := ""
if userConfig.Username != "" {
userPrefix = fmt.Sprintf("[user=%s] ", userConfig.Username)
}
s.logger.Debug(ctx, "Starting AI request", map[string]interface{}{
"user_prefix": userPrefix,
"url": apiURL + "/chat/completions",
"model": model,
"provider": userConfig.Provider,
})
// Create messages with just the user prompt - grammar field will enforce JSON structure
messages := []Message{{Role: "user", Content: prompt}}
// Check if the provider supports grammar field
supportsGrammar := s.supportsGrammarField(userConfig.Provider)
reqBody := OpenAIRequest{
Model: model,
Messages: messages,
Temperature: 0.7,
MaxTokens: s.getMaxTokensForModel(userConfig.Provider, userConfig.Model),
}
// Only include grammar field if the provider supports it
if supportsGrammar && grammar != "" {
reqBody.Grammar = grammar
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
s.logger.Error(ctx, "Failed to marshal AI request", err, map[string]interface{}{
"user_prefix": userPrefix,
})
span.SetAttributes(attribute.String("call.result", "marshal_failed"), attribute.String("error", err.Error()))
return "", contextutils.WrapErrorf(err, "failed to marshal request body")
}
s.logger.Debug(ctx, "Making AI HTTP request", map[string]interface{}{
"user_prefix": userPrefix,
"url": apiURL + "/chat/completions",
})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
s.logger.Error(ctx, "Failed to create AI HTTP request", err, map[string]interface{}{
"user_prefix": userPrefix,
})
span.SetAttributes(attribute.String("call.result", "request_creation_failed"), attribute.String("error", err.Error()))
return "", contextutils.WrapErrorf(err, "failed to create HTTP request")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "quizapp/1.0")
// The Authorization header is only sent when an API key is configured.
if apiKey != "" {
req.Header.Set("Authorization", "Bearer "+apiKey)
s.logger.Debug(ctx, "Using API key authentication", map[string]interface{}{
"user_prefix": userPrefix,
})
} else {
s.logger.Debug(ctx, "No API key provided, using anonymous access", map[string]interface{}{
"user_prefix": userPrefix,
})
}
// duration measures the round trip of Do only, not the body read below.
startTime := time.Now()
// NOTE(review): req was already created with ctx, so WithContext(ctx) looks redundant - confirm.
resp, err := s.httpClient.Do(req.WithContext(ctx))
duration := time.Since(startTime)
if err != nil {
s.logger.Error(ctx, "AI HTTP request failed", err, map[string]interface{}{
"user_prefix": userPrefix,
"duration": duration.String(),
})
span.SetAttributes(attribute.String("call.result", "http_request_failed"), attribute.String("error", err.Error()), attribute.String("duration", duration.String()))
return "", contextutils.WrapErrorf(err, "HTTP request failed after %v", duration)
}
// A failure to close the body is only logged; it does not affect the result.
defer func() {
if err := resp.Body.Close(); err != nil {
s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{
"error": err.Error(),
})
}
}()
s.logger.Info(ctx, "AI Service HTTP request completed", map[string]interface{}{
"user_prefix": userPrefix,
"duration": duration.String(),
"status_code": resp.StatusCode,
})
// The whole body is read up front so it can be quoted in errors and logs.
body, err := io.ReadAll(resp.Body)
if err != nil {
span.SetAttributes(attribute.String("call.result", "body_read_failed"), attribute.String("error", err.Error()))
return "", contextutils.WrapErrorf(err, "failed to read response body")
}
if resp.StatusCode != http.StatusOK {
span.SetAttributes(attribute.String("call.result", "http_error"), attribute.Int("status_code", resp.StatusCode), attribute.String("body", string(body)))
// Handle rate limit errors specifically
if resp.StatusCode == http.StatusTooManyRequests {
return "", contextutils.WrapErrorf(contextutils.ErrRateLimit, "Rate limit exceeded for AI provider %s: %s", userConfig.Provider, string(body))
}
return "", contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "API request failed with status %d to %s: %s", resp.StatusCode, apiURL+"/chat/completions", string(body))
}
var openAIResp OpenAIResponse
if err := json.Unmarshal(body, &openAIResp); err != nil {
span.SetAttributes(attribute.String("call.result", "json_unmarshal_failed"), attribute.String("error", err.Error()), attribute.String("body", string(body)))
// NOTE(review): the format string uses %w; whether WrapErrorf supports that verb is not visible here - confirm.
return "", contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "failed to parse AI response as JSON: %w. Raw Response: %s", err, string(body))
}
// A 200 response can still carry an error object in its body.
if openAIResp.Error != nil {
span.SetAttributes(attribute.String("call.result", "api_error"), attribute.String("error_message", openAIResp.Error.Message), attribute.String("error_type", openAIResp.Error.Type))
return "", contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "OpenAI API error: %s", openAIResp.Error.Message)
}
if len(openAIResp.Choices) == 0 {
span.SetAttributes(attribute.String("call.result", "no_choices"))
return "", contextutils.WrapError(contextutils.ErrAIResponseInvalid, "no response from OpenAI")
}
// Only the first choice is used.
content := openAIResp.Choices[0].Message.Content
if content == "" {
span.SetAttributes(attribute.String("call.result", "empty_content"))
s.logger.Warn(ctx, "OpenAI returned empty content", map[string]interface{}{
"user_prefix": userPrefix,
"response": string(body),
"prompt_length": len(prompt),
})
return "", contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI returned empty content")
}
span.SetAttributes(attribute.String("call.result", "success"), attribute.Int("content_length", len(content)), attribute.String("duration", duration.String()))
// Extract usage information if available and track it internally
if openAIResp.Usage != nil {
userID := contextutils.GetUserIDFromContext(ctx)
apiKeyID := contextutils.GetAPIKeyIDFromContext(ctx)
s.trackAIUsage(ctx, userConfig, *openAIResp.Usage, userID, apiKeyID)
} else {
// Missing usage is not an error: the content is still returned below.
s.logger.Warn(ctx, "No usage information available", map[string]any{
"user_prefix": userPrefix,
"response": string(body),
"prompt_length": len(prompt),
})
// This sets "call.result" a second time, after the "success" value set above.
span.SetAttributes(attribute.String("call.result", "no_usage_information"), attribute.String("response", string(body)), attribute.String("prompt_length", strconv.Itoa(len(prompt))))
}
return content, nil
}
// callOpenAIWithRetry calls callOpenAI and retries when the call fails with an
// error matching contextutils.ErrAIResponseInvalid (for example empty
// content). Up to maxRetries retries are made, waiting attempt seconds before
// each one; the wait is abandoned with ctx.Err() if ctx is cancelled.
//
// Any other error is returned immediately without retrying. When every attempt
// fails, the last error is returned wrapped with the number of attempts. A nil
// userConfig returns an error wrapping contextutils.ErrAIConfigInvalid.
func (s *AIService) callOpenAIWithRetry(ctx context.Context, userConfig *models.UserAIConfig, prompt, grammar string) (result string, err error) {
	// userConfig is dereferenced for the span attributes below, so reject nil
	// here with the same error callOpenAI uses.
	if userConfig == nil {
		return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "userConfig is required")
	}

	// Keep the traced context so each callOpenAI attempt is attached to this span.
	ctx, span := observability.TraceAIFunction(ctx, "call_openai_with_retry",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		attribute.String("ai.username", userConfig.Username),
		attribute.Int("prompt.length", len(prompt)),
		attribute.Bool("grammar.enabled", grammar != ""),
	)
	defer observability.FinishSpan(span, &err)

	const maxRetries = 2
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			// Linear backoff that stops waiting as soon as ctx is cancelled.
			backoff := time.NewTimer(time.Duration(attempt) * time.Second)
			select {
			case <-backoff.C:
			case <-ctx.Done():
				backoff.Stop()
				return "", ctx.Err()
			}
		}

		result, err = s.callOpenAI(ctx, userConfig, prompt, grammar)
		if err == nil {
			return result, nil
		}
		// Only invalid-response errors are worth retrying.
		if !contextutils.IsError(err, contextutils.ErrAIResponseInvalid) {
			return "", err
		}
		lastErr = err
		if attempt < maxRetries {
			s.logger.Warn(ctx, "Retrying AI request due to empty content", map[string]interface{}{
				"attempt":     attempt + 1,
				"max_retries": maxRetries,
				"error":       err.Error(),
			})
		}
	}
	return "", contextutils.WrapErrorf(lastErr, "AI returned empty content after %d attempts", maxRetries+1)
}
// callOpenAIStream sends prompt to the provider's OpenAI-compatible
// /chat/completions endpoint with streaming enabled and forwards every
// non-empty content delta to chunks as it arrives.
//
// The base URL comes from the entry in s.cfg.Providers whose Code equals
// userConfig.Provider. The grammar is attached only when it is non-empty and
// supportsGrammarField reports that the provider accepts it. Each delta passes
// through filterThinkingContent before being sent. If a chunk carries usage
// data and supportsUsageInStreaming is true for the provider, the last such
// value is handed to trackAIUsage after the stream ends.
//
// The function never closes chunks. Errors are built from ErrAIConfigInvalid
// (bad arguments, no provider URL), ErrRateLimit (HTTP 429) or
// ErrAIRequestFailed (transport, other HTTP status, in-stream API error, read
// error). ctx.Err() is returned if ctx is cancelled while a send on chunks is
// blocked.
func (s *AIService) callOpenAIStream(ctx context.Context, userConfig *models.UserAIConfig, prompt, grammar string, chunks chan<- string) (err error) {
	if userConfig == nil {
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "userConfig is required")
	}
	_, span := observability.TraceAIFunction(ctx, "call_openai_stream",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		attribute.String("ai.username", userConfig.Username),
		attribute.Int("prompt.length", len(prompt)),
		attribute.Bool("grammar.enabled", grammar != ""),
	)
	defer observability.FinishSpan(span, &err)

	// Validate input parameters
	if userConfig.Provider == "" {
		span.SetAttributes(attribute.String("stream.result", "empty_provider"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "provider is required")
	}
	if userConfig.Model == "" {
		span.SetAttributes(attribute.String("stream.result", "empty_model"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "model is required")
	}
	if prompt == "" {
		span.SetAttributes(attribute.String("stream.result", "empty_prompt"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "prompt cannot be empty")
	}
	if chunks == nil {
		span.SetAttributes(attribute.String("stream.result", "nil_chunks_channel"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "chunks channel is required")
	}

	apiURL := ""
	model := userConfig.Model
	apiKey := userConfig.APIKey

	// Look up the default URL from provider config. A missing config is
	// reported below as "no base URL configured" instead of panicking.
	if s.cfg != nil && s.cfg.Providers != nil {
		for _, providerConfig := range s.cfg.Providers {
			if providerConfig.Code == userConfig.Provider && providerConfig.URL != "" {
				apiURL = providerConfig.URL
				break
			}
		}
	}
	if apiURL == "" {
		span.SetAttributes(attribute.String("stream.result", "no_url_configured"), attribute.String("provider", userConfig.Provider))
		return contextutils.WrapErrorf(contextutils.ErrAIConfigInvalid, "no base URL configured for provider '%s'", userConfig.Provider)
	}
	endpoint := apiURL + "/chat/completions"

	userPrefix := ""
	if userConfig.Username != "" {
		userPrefix = fmt.Sprintf("[user=%s] ", userConfig.Username)
	}
	s.logger.Info(ctx, "AI Service Starting streaming request", map[string]interface{}{
		"user_prefix": userPrefix,
		"api_url":     endpoint,
		"model":       model,
		"provider":    userConfig.Provider,
	})

	// Create messages with just the user prompt - grammar field will enforce JSON structure
	messages := []Message{{Role: "user", Content: prompt}}
	reqBody := OpenAIRequest{
		Model:       model,
		Messages:    messages,
		Temperature: 0.7,
		MaxTokens:   s.getMaxTokensForModel(userConfig.Provider, userConfig.Model),
		Stream:      true, // Enable streaming
	}
	// Only include grammar field if the provider supports it
	if s.supportsGrammarField(userConfig.Provider) && grammar != "" {
		reqBody.Grammar = grammar
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		s.logger.Error(ctx, "Failed to marshal request", err, map[string]interface{}{
			"user_prefix": userPrefix,
		})
		span.SetAttributes(attribute.String("stream.result", "marshal_failed"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(err, "failed to marshal streaming request body")
	}

	s.logger.Info(ctx, "AI Service Making streaming HTTP request", map[string]interface{}{
		"user_prefix": userPrefix,
		"api_url":     endpoint,
	})
	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
	if err != nil {
		s.logger.Error(ctx, "Failed to create HTTP request", err, map[string]interface{}{
			"user_prefix": userPrefix,
		})
		span.SetAttributes(attribute.String("stream.result", "request_creation_failed"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(err, "failed to create streaming HTTP request")
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Cache-Control", "no-cache")
	req.Header.Set("User-Agent", "quizapp/1.0")
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
		s.logger.Info(ctx, "AI Service Using API key authentication", map[string]interface{}{
			"user_prefix": userPrefix,
		})
	} else {
		s.logger.Info(ctx, "AI Service No API key provided, using anonymous access", map[string]interface{}{
			"user_prefix": userPrefix,
		})
	}

	startTime := time.Now()
	// req already carries ctx via NewRequestWithContext.
	resp, err := s.httpClient.Do(req)
	if err != nil {
		s.logger.Error(ctx, "HTTP request failed", err, map[string]interface{}{
			"user_prefix": userPrefix,
		})
		span.SetAttributes(attribute.String("stream.result", "http_request_failed"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "http client error: %w", err)
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{
				"error": closeErr.Error(),
			})
		}
	}()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		span.SetAttributes(attribute.String("stream.result", "http_error"), attribute.Int("status_code", resp.StatusCode), attribute.String("body", string(body)))
		// Handle rate limit errors specifically
		if resp.StatusCode == http.StatusTooManyRequests {
			return contextutils.WrapErrorf(contextutils.ErrRateLimit, "Rate limit exceeded for AI provider %s: %s", userConfig.Provider, string(body))
		}
		return contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "API request failed with status %d to %s: %s", resp.StatusCode, endpoint, string(body))
	}

	s.logger.Info(ctx, "AI Service Streaming response started", map[string]interface{}{
		"user_prefix": userPrefix,
		"duration":    time.Since(startTime).String(),
	})

	// Read the streaming response line by line. The default Scanner limit of
	// 64 KiB per line makes Scan stop with bufio.ErrTooLong on a large event,
	// so allow lines of up to 1 MiB.
	const (
		initialLineBufBytes = 64 * 1024
		maxLineBytes        = 1024 * 1024
	)
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, initialLineBufBytes), maxLineBytes)

	var chunkCount int
	var totalContentLength int
	var finalUsage *Usage

	// Usage information may or may not be included in streaming response chunks depending on the provider.
	// We'll only try to extract usage from chunks for providers that support it in streaming responses.
	for scanner.Scan() {
		line := scanner.Text()
		// Skip empty lines and comments
		if line == "" || strings.HasPrefix(line, ":") {
			continue
		}
		// Only "data" fields are of interest. The Server-Sent Events format
		// makes the single space after the colon optional.
		if !strings.HasPrefix(line, "data:") {
			continue
		}
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if data == "" {
			continue
		}
		// Check for end of stream
		if data == "[DONE]" {
			break
		}

		// Parse the JSON chunk; a malformed chunk is logged and skipped.
		var streamResp OpenAIStreamResponse
		if parseErr := json.Unmarshal([]byte(data), &streamResp); parseErr != nil {
			s.logger.Warn(ctx, "AI Service WARNING: Failed to parse streaming chunk", map[string]interface{}{
				"error": parseErr.Error(),
				"data":  data,
			})
			span.SetAttributes(attribute.String("stream.result", "chunk_parse_failed"), attribute.String("error", parseErr.Error()), attribute.String("data", data))
			continue
		}
		if streamResp.Error != nil {
			span.SetAttributes(attribute.String("stream.result", "api_streaming_error"), attribute.String("error_message", streamResp.Error.Message), attribute.String("error_type", streamResp.Error.Type))
			return contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "OpenAI API streaming error: %s", streamResp.Error.Message)
		}

		// Extract usage information if available (usually in the final chunk)
		// Only check for usage if the provider supports it in streaming responses
		if streamResp.Usage != nil && s.supportsUsageInStreaming(userConfig.Provider) {
			finalUsage = streamResp.Usage
		}

		// Extract content from the chunk
		if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" {
			content := streamResp.Choices[0].Delta.Content
			// Counts the raw delta, before thinking content is filtered out.
			totalContentLength += len(content)
			// Filter out thinking content for thinking models
			filteredContent := s.filterThinkingContent(content, model)
			if filteredContent != "" {
				select {
				case chunks <- filteredContent:
					chunkCount++
				case <-ctx.Done():
					span.SetAttributes(attribute.String("stream.result", "context_cancelled"))
					return ctx.Err()
				}
			}
		}

		// Check if streaming is finished
		if len(streamResp.Choices) > 0 && streamResp.Choices[0].FinishReason != nil {
			break
		}
	}
	if scanErr := scanner.Err(); scanErr != nil {
		span.SetAttributes(attribute.String("stream.result", "scanner_error"), attribute.String("error", scanErr.Error()))
		return contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "error reading streaming response: %w", scanErr)
	}

	s.logger.Info(ctx, "AI Service Streaming response completed", map[string]interface{}{
		"user_prefix":          userPrefix,
		"duration":             time.Since(startTime).String(),
		"chunk_count":          chunkCount,
		"total_content_length": totalContentLength,
	})

	// Extract usage information if available and track it internally
	if finalUsage != nil {
		userID := contextutils.GetUserIDFromContext(ctx)
		apiKeyID := contextutils.GetAPIKeyIDFromContext(ctx)
		s.trackAIUsage(ctx, userConfig, *finalUsage, userID, apiKeyID)
	} else {
		// For providers that don't support usage in streaming, this is expected behavior
		if !s.supportsUsageInStreaming(userConfig.Provider) {
			s.logger.Debug(ctx, "No usage information in streaming response (expected - provider doesn't support usage in streaming)", map[string]any{
				"user_prefix":     userPrefix,
				"chunk_count":     chunkCount,
				"content_length":  totalContentLength,
				"provider":        userConfig.Provider,
				"usage_supported": s.supportsUsageInStreaming(userConfig.Provider),
			})
		} else {
			s.logger.Warn(ctx, "No usage information available in streaming response", map[string]any{
				"user_prefix":    userPrefix,
				"chunk_count":    chunkCount,
				"content_length": totalContentLength,
				"provider":       userConfig.Provider,
			})
		}
		span.SetAttributes(attribute.String("stream.result", "no_usage_information"), attribute.Int("chunk_count", chunkCount), attribute.Int("content_length", totalContentLength))
	}

	span.SetAttributes(attribute.String("stream.result", "success"), attribute.Int("chunk_count", chunkCount), attribute.Int("total_content_length", totalContentLength), attribute.String("duration", time.Since(startTime).String()))
	return nil
}
// filterThinkingContent decides what part of a piece of model output may be
// shown to the user. For models that isThinkingModel does not recognise the
// content is returned untouched. For thinking models:
//   - content containing a <thinking> or </thinking> tag is dropped entirely;
//   - content containing "The answer is:" is reduced to the first non-blank
//     line after that marker, whitespace-trimmed ("" if there is none);
//   - content whose trimmed form starts with a known deliberation opener
//     ("I need to", "Let me think", "First, I'll") is dropped;
//   - anything else is returned unchanged.
func (s *AIService) filterThinkingContent(content, model string) string {
	if !s.isThinkingModel(model) {
		return content
	}

	// Either tag marks the content as part of a thinking section.
	for _, tag := range []string{"<thinking>", "</thinking>"} {
		if strings.Contains(content, tag) {
			return ""
		}
	}

	// Keep only the first meaningful line following the answer marker.
	if _, tail, found := strings.Cut(content, "The answer is:"); found {
		for _, candidate := range strings.Split(tail, "\n") {
			if answer := strings.TrimSpace(candidate); answer != "" {
				return answer
			}
		}
		return ""
	}

	stripped := strings.TrimSpace(content)
	for _, opener := range []string{"I need to", "Let me think", "First, I'll"} {
		if strings.HasPrefix(stripped, opener) {
			return ""
		}
	}
	return content
}
// isThinkingModel reports whether model is treated as a reasoning/thinking
// model. The test is a case-insensitive substring match of the model name
// against a fixed list of markers, so any name that merely contains a marker
// (for example one containing "o1" or "gpt-4") also matches.
func (s *AIService) isThinkingModel(model string) bool {
	// Every marker is written in lowercase, so lowering the input once is
	// enough for a case-insensitive comparison.
	lowered := strings.ToLower(model)
	for _, marker := range [...]string{
		"o1-preview",
		"o1-mini",
		"o1",
		"qwen2.5-coder:32b",
		"deepseek-r1",
		"marco-o1",
		"gpt-4",
		"gpt-4-turbo",
		"claude-3",
	} {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}
// cleanJSONResponse strips a surrounding markdown code fence from an AI
// response so that the remainder can be parsed as JSON.
//
// When supportsGrammarField reports true for provider the response is
// returned untouched, since such providers are expected to emit bare JSON.
// Otherwise the response is whitespace-trimmed; if it then starts with a
// ``` fence, the fence, an optional "json" language tag (matched
// case-insensitively, so ```JSON and ```Json work too) and a trailing ```
// fence are removed, and the result is trimmed again. Text outside the
// fence (before the opening or after the closing marker) is not handled.
func (s *AIService) cleanJSONResponse(ctx context.Context, response, provider string) string {
	_, span := observability.TraceAIFunction(ctx, "clean_json_response",
		attribute.String("ai.provider", provider),
		attribute.Int("response.length", len(response)),
	)
	defer span.End()

	// If the provider supports grammar field, we expect clean JSON
	if s.supportsGrammarField(provider) {
		return response
	}

	const (
		fence   = "```"
		jsonTag = "json"
	)
	cleaned := strings.TrimSpace(response)
	if strings.HasPrefix(cleaned, fence) {
		cleaned = cleaned[len(fence):]
		// Drop the language tag regardless of letter case; a tag left in
		// place would make the subsequent JSON parse fail.
		if len(cleaned) >= len(jsonTag) && strings.EqualFold(cleaned[:len(jsonTag)], jsonTag) {
			cleaned = cleaned[len(jsonTag):]
		}
		cleaned = strings.TrimSuffix(cleaned, fence)
	}
	return strings.TrimSpace(cleaned)
}
// parseQuestionsResponse turns a raw AI response that should hold a JSON
// array of question objects into a slice of questions.
//
// The response is first passed through cleanJSONResponse, then decoded.
// Each element is converted with createQuestionFromData and checked with
// ValidateQuestionSchema. Elements that are nil, fail conversion, or fail
// schema validation are skipped; schema failures are also recorded in the
// package-level SchemaValidation* counters (the detail list is capped at the
// 10 most recent entries per question type). Conversion and schema errors
// are summarised in one warning log.
//
// An ErrAIResponseInvalid-based error is returned when an argument is empty,
// the response is not a JSON array of objects, the array is empty, or no
// element survives validation. A panic raised while parsing is recovered,
// logged, and reported to the caller as an ErrInternalError-based error
// rather than as a (nil, nil) result.
func (s *AIService) parseQuestionsResponse(ctx context.Context, response, language, level string, qType models.QuestionType, provider string) (result0 []*models.Question, err error) {
	if s == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "AIService instance is nil")
	}
	_, span := observability.TraceAIFunction(ctx, "parse_questions_response",
		observability.AttributeQuestionType(qType),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("ai.provider", provider),
		attribute.Int("response.length", len(response)),
	)
	// Registered first so it runs last and sees the error set by the
	// recover handler below.
	defer observability.FinishSpan(span, &err)
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		s.logger.Error(ctx, "PANIC in parseQuestionsResponse", nil, map[string]interface{}{
			"panic":    fmt.Sprintf("%v", r),
			"response": response,
			"stack":    string(debug.Stack()),
		})
		span.SetAttributes(attribute.String("parse.result", "panic"), attribute.String("panic", fmt.Sprintf("%v", r)))
		// Without this the named results stay zero and the caller would
		// see success with no questions.
		result0 = nil
		err = contextutils.WrapErrorf(contextutils.ErrInternalError, "panic while parsing AI questions response: %v", r)
	}()

	// Validate input parameters
	if response == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_response"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI provider returned empty response")
	}
	if language == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_language"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "language cannot be empty")
	}
	if level == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_level"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "level cannot be empty")
	}

	// Clean the response to handle markdown code blocks for providers without grammar support
	cleanedResponse := s.cleanJSONResponse(ctx, response, provider)
	if cleanedResponse == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_cleaned_response"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI provider returned empty response after cleaning")
	}

	// With grammar field enforcement, we should get clean JSON directly
	// No need for complex extraction - just parse the response directly
	var questions []map[string]interface{}
	if unmarshalErr := json.Unmarshal([]byte(cleanedResponse), &questions); unmarshalErr != nil {
		span.SetAttributes(attribute.String("parse.result", "json_unmarshal_failed"), attribute.String("error", unmarshalErr.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "failed to parse AI response as JSON: %w", unmarshalErr)
	}
	if len(questions) == 0 {
		span.SetAttributes(attribute.String("parse.result", "no_questions_in_response"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI provider returned no questions in response")
	}

	var result []*models.Question
	var validationErrors []string
	var skippedCount int
	for i, qData := range questions {
		if qData == nil {
			skippedCount++
			span.SetAttributes(attribute.String("parse.result", "nil_question_data"), attribute.Int("question_index", i))
			continue
		}

		question, createErr := s.createQuestionFromData(ctx, qData, language, level, qType)
		if createErr != nil {
			// Try to extract more info about the failure: report one
			// field that is nil or empty, if any.
			var failedField, failedValue string
			for k, v := range qData {
				if v == nil || v == "" {
					failedField = k
					failedValue = fmt.Sprintf("%v", v)
					break
				}
			}
			validationErrors = append(validationErrors, fmt.Sprintf("question %d: %v (field: %s, value: %s)", i+1, createErr, failedField, failedValue))
			span.SetAttributes(attribute.String("parse.result", "question_creation_failed"), attribute.Int("question_index", i), attribute.String("error", createErr.Error()))
			continue
		}
		if question == nil {
			skippedCount++
			span.SetAttributes(attribute.String("parse.result", "nil_question_after_creation"), attribute.Int("question_index", i))
			continue
		}

		// Coerce correct_answer to int if it's a float64 (for schema validation)
		if content := question.Content; content != nil {
			if asFloat, isFloat := content["correct_answer"].(float64); isFloat {
				content["correct_answer"] = int(asFloat)
			}
		}

		valid, schemaErr := s.ValidateQuestionSchema(ctx, qType, question)
		if schemaErr != nil {
			validationErrors = append(validationErrors, fmt.Sprintf("question %d schema validation error: %v", i+1, schemaErr))
			span.SetAttributes(attribute.String("parse.result", "schema_validation_error"), attribute.Int("question_index", i), attribute.String("error", schemaErr.Error()))
		}
		if !valid {
			SchemaValidationMu.Lock()
			SchemaValidationFailures[qType]++
			if schemaErr != nil {
				SchemaValidationFailureDetails[qType] = append(SchemaValidationFailureDetails[qType], schemaErr.Error())
			} else {
				SchemaValidationFailureDetails[qType] = append(SchemaValidationFailureDetails[qType], "validation failed")
			}
			// Keep only the 10 most recent failure details.
			if len(SchemaValidationFailureDetails[qType]) > 10 {
				SchemaValidationFailureDetails[qType] = SchemaValidationFailureDetails[qType][len(SchemaValidationFailureDetails[qType])-10:]
			}
			SchemaValidationMu.Unlock()
			skippedCount++
			span.SetAttributes(attribute.String("parse.result", "schema_validation_failed"), attribute.Int("question_index", i))
			continue // skip invalid question
		}
		result = append(result, question)
	}

	// Log validation summary
	if len(validationErrors) > 0 {
		s.logger.Warn(ctx, "AI Service WARNING: validation errors in response", map[string]interface{}{
			"validation_errors_count": len(validationErrors),
			"validation_errors":       strings.Join(validationErrors, "; "),
		})
		span.SetAttributes(attribute.String("parse.result", "validation_errors"), attribute.String("errors", strings.Join(validationErrors, "; ")))
	}
	if len(result) == 0 {
		span.SetAttributes(attribute.String("parse.result", "no_valid_questions"), attribute.Int("total_questions", len(questions)), attribute.Int("skipped_count", skippedCount))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "AI provider returned only invalid or empty questions (total: %d, skipped: %d)", len(questions), skippedCount)
	}
	span.SetAttributes(attribute.String("parse.result", "success"), attribute.Int("valid_questions", len(result)), attribute.Int("total_questions", len(questions)), attribute.Int("skipped_count", skippedCount))
	return result, nil
}
// createQuestionFromData builds a models.Question from one decoded JSON
// object returned by the AI provider.
//
// The object must pass validateQuestionContent for qType; reading
// comprehension objects additionally need a string passage, a string
// question, exactly four string options and a correct_answer key.
// correct_answer may be an int, a float64 (truncated to int), a numeric
// string (used as an index), or the text of one of the options (resolved to
// that option's index). When options is an array the resulting index must lie
// within it.
//
// If no non-empty explanation string is present, a default that depends on
// qType is used and also written back into data, which becomes the
// question's Content. Every failure is reported as an
// ErrAIResponseInvalid-based error, except a nil receiver, which yields an
// ErrInternalError-based error.
func (s *AIService) createQuestionFromData(ctx context.Context, data map[string]interface{}, language, level string, qType models.QuestionType) (result0 *models.Question, err error) {
	if s == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "AIService instance is nil")
	}
	_, span := observability.TraceAIFunction(ctx, "create_question_from_data",
		observability.AttributeQuestionType(qType),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.Int("data.fields", len(data)),
	)
	defer observability.FinishSpan(span, &err)

	// Argument checks.
	switch {
	case data == nil:
		span.SetAttributes(attribute.String("creation.result", "nil_data"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "question data is nil")
	case language == "":
		span.SetAttributes(attribute.String("creation.result", "empty_language"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "language cannot be empty")
	case level == "":
		span.SetAttributes(attribute.String("creation.result", "empty_level"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "level cannot be empty")
	}

	// Structural validation for the question type. On failure, list the
	// keys whose value is nil or "" to make the error easier to act on.
	if contentOK, reason := s.validateQuestionContent(ctx, qType, data); !contentOK {
		var emptyKeys []string
		for key, value := range data {
			if value == nil || value == "" {
				emptyKeys = append(emptyKeys, key)
			}
		}
		if len(emptyKeys) == 0 {
			span.SetAttributes(attribute.String("creation.result", "validation_failed"), attribute.String("error", reason))
			return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "invalid question content structure: %s", reason)
		}
		span.SetAttributes(attribute.String("creation.result", "validation_failed_with_missing_fields"), attribute.String("missing_fields", strings.Join(emptyKeys, ",")))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "invalid question content structure: %s. Missing or empty fields: %v", reason, emptyKeys)
	}

	// data is not modified until the explanation is filled in below, so the
	// options array can be resolved once and reused.
	choices, hasChoices := data["options"].([]interface{})

	// Defensive: For reading comprehension, check passage, question, options, correct_answer
	if qType == models.ReadingComprehension {
		if _, isText := data["passage"].(string); !isText {
			span.SetAttributes(attribute.String("creation.result", "reading_missing_passage"))
			return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "reading comprehension question missing or invalid 'passage' field")
		}
		if _, isText := data["question"].(string); !isText {
			span.SetAttributes(attribute.String("creation.result", "reading_missing_question"))
			return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "reading comprehension question missing or invalid 'question' field")
		}
		if !hasChoices || len(choices) != 4 {
			span.SetAttributes(attribute.String("creation.result", "reading_invalid_options"))
			return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "reading comprehension question missing or invalid 'options' field (must be array of 4 strings)")
		}
		for pos, choice := range choices {
			if _, isText := choice.(string); !isText {
				span.SetAttributes(attribute.String("creation.result", "reading_invalid_option_type"), attribute.Int("option_index", pos))
				return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "reading comprehension question 'options' must be array of strings, found invalid type at index %d", pos)
			}
		}
		if _, present := data["correct_answer"]; !present {
			span.SetAttributes(attribute.String("creation.result", "reading_missing_correct_answer"))
			return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "reading comprehension question missing 'correct_answer' field")
		}
	}

	// Parse correct_answer as index (integer)
	rawAnswer, hasAnswer := data["correct_answer"]
	if !hasAnswer {
		span.SetAttributes(attribute.String("creation.result", "missing_correct_answer"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "missing correct_answer field")
	}
	var answerIndex int
	switch answer := rawAnswer.(type) {
	case int:
		answerIndex = answer
	case float64:
		answerIndex = int(answer)
	case string:
		// Handle string indices like "0", "1", "2", "3"
		if parsed, convErr := strconv.Atoi(answer); convErr == nil {
			answerIndex = parsed
			break
		}
		// Otherwise treat it as answer text and look it up in options.
		if !hasChoices {
			span.SetAttributes(attribute.String("creation.result", "no_options_for_text_answer"))
			return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "correct_answer is text '%s' but no options available to match against", answer)
		}
		match := -1
		for pos, choice := range choices {
			if text, isText := choice.(string); isText && text == answer {
				match = pos
				break
			}
		}
		if match < 0 {
			span.SetAttributes(attribute.String("creation.result", "correct_answer_not_found_in_options"))
			return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "correct_answer '%s' not found in options", answer)
		}
		answerIndex = match
	default:
		span.SetAttributes(attribute.String("creation.result", "invalid_correct_answer_type"), attribute.String("type", fmt.Sprintf("%T", answer)))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "invalid correct_answer type: %T", answer)
	}

	// Validate correct answer index
	if hasChoices && (answerIndex < 0 || answerIndex >= len(choices)) {
		span.SetAttributes(attribute.String("creation.result", "invalid_correct_answer_index"), attribute.Int("index", answerIndex), attribute.Int("options_count", len(choices)))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "correct_answer index %d is out of range (0-%d)", answerIndex, len(choices)-1)
	}

	// Note: options are not shuffled here - the frontend handles shuffling,
	// which prevents a mismatch between backend and frontend answer indices.

	// Get explanation or provide default
	explanation, _ := data["explanation"].(string)
	if explanation == "" {
		switch qType {
		case models.Vocabulary:
			explanation = "This vocabulary question tests your knowledge of words in context."
		case models.ReadingComprehension:
			explanation = "This reading comprehension question tests your understanding of the passage."
		case models.FillInBlank:
			explanation = "This fill-in-the-blank question tests your grammar and vocabulary knowledge."
		case models.QuestionAnswer:
			explanation = "This question tests your conversational and practical language skills."
		default:
			explanation = "This question tests your language skills."
		}
		// Add the explanation to the data for schema validation
		data["explanation"] = explanation
	}

	span.SetAttributes(attribute.String("creation.result", "success"))
	return &models.Question{
		Type:            qType,
		Language:        language,
		Level:           level,
		DifficultyScore: s.getDifficultyScore(level),
		Content:         data,
		CorrectAnswer:   answerIndex,
		Explanation:     explanation,
		CreatedAt:       time.Now(),
	}, nil
}
// parseQuestionResponse converts a raw AI response holding a single JSON
// object into a question.
//
// The response is passed through cleanJSONResponse, decoded, and handed to
// createQuestionFromData; a failure in either step is logged and returned as
// an ErrAIResponseInvalid-based error. The question is then checked with
// ValidateQuestionSchema. A schema error is logged and a failed validation is
// recorded in the package-level SchemaValidation* counters (details capped at
// the 10 most recent per type), but the question is still returned.
//
// NOTE(review): unlike parseQuestionsResponse, a schema-invalid question is
// not dropped here and correct_answer is not coerced from float64 to int
// before validation - confirm this difference is intended.
func (s *AIService) parseQuestionResponse(ctx context.Context, response, language, level string, qType models.QuestionType, provider string) (result0 *models.Question, err error) {
	_, span := observability.TraceAIFunction(ctx, "parse_question_response",
		observability.AttributeQuestionType(qType),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("ai.provider", provider),
		attribute.Int("response.length", len(response)),
	)
	defer observability.FinishSpan(span, &err)

	// Strip markdown fences for providers without grammar support, then
	// decode the remainder directly.
	cleaned := s.cleanJSONResponse(ctx, response, provider)
	var payload map[string]interface{}
	if decodeErr := json.Unmarshal([]byte(cleaned), &payload); decodeErr != nil {
		s.logger.Error(ctx, "Failed to parse JSON response", decodeErr, map[string]interface{}{
			"raw_response": response,
		})
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "failed to parse AI response as JSON: %w", decodeErr)
	}

	parsed, createErr := s.createQuestionFromData(ctx, payload, language, level, qType)
	if createErr != nil {
		s.logger.Error(ctx, "Failed to create question from data", createErr, map[string]interface{}{
			"raw_question_data":   payload,
			"full_model_response": response,
		})
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "failed to create question: %w", createErr)
	}

	schemaOK, schemaErr := s.ValidateQuestionSchema(ctx, qType, parsed)
	if schemaErr != nil {
		s.logger.Error(ctx, "Schema validation error for question", schemaErr, nil)
	}
	if schemaOK {
		return parsed, nil
	}

	// Record the failure; the question is returned regardless.
	detail := "validation failed"
	if schemaErr != nil {
		detail = schemaErr.Error()
	}
	SchemaValidationMu.Lock()
	SchemaValidationFailures[qType]++
	recent := append(SchemaValidationFailureDetails[qType], detail)
	if len(recent) > 10 {
		recent = recent[len(recent)-10:]
	}
	SchemaValidationFailureDetails[qType] = recent
	SchemaValidationMu.Unlock()

	return parsed, nil
}
// getDifficultyScore maps a level name to a score in [0.0, 1.0] based on the
// level's position in a configured level list: the first level scores 0.0
// and the last 1.0.
//
// Every entry of s.cfg.LanguageLevels is searched and the first list that
// contains level decides the score. The default of 0.5 is returned when the
// configuration is missing, when the level is not found, or when the matching
// list has a single entry (a position cannot be normalised there; dividing
// by len-1 would produce NaN).
func (s *AIService) getDifficultyScore(level string) float64 {
	const defaultScore = 0.5

	if s.cfg == nil || s.cfg.LanguageLevels == nil {
		return defaultScore
	}
	for _, langConfig := range s.cfg.LanguageLevels {
		for i, lvl := range langConfig.Levels {
			if lvl != level {
				continue
			}
			last := len(langConfig.Levels) - 1
			if last == 0 {
				// Single-level list: avoid 0/0.
				return defaultScore
			}
			return float64(i) / float64(last)
		}
	}
	// Default to middle difficulty if level not found
	return defaultScore
}
// validateQuestionContent checks that content has the fields required for
// qType and returns (true, "") when it does, or (false, reason) otherwise.
//
// Required fields per type:
//   - Vocabulary: sentence, question (non-empty strings) and options; the
//     options array must hold exactly four entries and the question text
//     must occur inside the sentence. correct_answer is not checked here.
//   - ReadingComprehension: passage, question, options, correct_answer.
//   - FillInBlank and QuestionAnswer: question, options, correct_answer.
//
// options must be an array of at least four non-empty strings.
// correct_answer must be a non-negative int, a non-negative whole float64, a
// string that parses as an integer, or a string equal to one of the options.
// Fields are checked in the order listed above, so the reported field is the
// same on every run for a given input. An unknown qType is rejected. content
// is only read, never modified.
func (s *AIService) validateQuestionContent(ctx context.Context, qType models.QuestionType, content map[string]interface{}) (bool, string) {
	_, span := observability.TraceAIFunction(ctx, "validate_question_content",
		observability.AttributeQuestionType(qType),
		attribute.Int("content.fields", len(content)),
	)
	defer span.End()

	// Validate input parameters
	if content == nil {
		span.SetAttributes(attribute.String("validation.result", "nil_content"))
		return false, "question content cannot be nil"
	}

	// isString accepts only non-empty strings.
	isString := func(v interface{}) bool {
		text, ok := v.(string)
		return ok && text != ""
	}
	// isStringSlice accepts an array of at least four non-empty strings.
	isStringSlice := func(v interface{}) bool {
		items, ok := v.([]interface{})
		if !ok || len(items) < 4 {
			return false
		}
		for _, item := range items {
			if text, isText := item.(string); !isText || text == "" {
				return false
			}
		}
		return true
	}
	// isCorrectAnswer accepts an index-like number or string, or the text
	// of one of the options.
	isCorrectAnswer := func(v interface{}) bool {
		switch val := v.(type) {
		case int:
			return val >= 0
		case float64:
			return val >= 0 && val == float64(int(val)) // Must be whole number
		case string:
			// Accept string indices like "0", "1", "2", "3"
			if _, convErr := strconv.Atoi(val); convErr == nil {
				return true
			}
			// Or accept answer text that matches one of the options
			if options, ok := content["options"].([]interface{}); ok {
				for _, opt := range options {
					if optStr, isText := opt.(string); isText && optStr == val {
						return true
					}
				}
			}
			return false
		default:
			return false
		}
	}

	// fieldCheck pairs a required key with its validator. A slice is used
	// rather than a map so that the check order is fixed.
	type fieldCheck struct {
		name  string
		valid func(interface{}) bool
	}
	var (
		label  string
		checks []fieldCheck
	)
	switch qType {
	case models.Vocabulary:
		label = "Vocabulary"
		checks = []fieldCheck{
			{"sentence", isString},
			{"question", isString},
			{"options", isStringSlice},
		}
	case models.ReadingComprehension:
		label = "ReadingComprehension"
		checks = []fieldCheck{
			{"passage", isString},
			{"question", isString},
			{"options", isStringSlice},
			{"correct_answer", isCorrectAnswer},
		}
	case models.FillInBlank:
		// Fill-in-blank questions use multiple choice format like all other types
		label = "FillInBlank"
		checks = []fieldCheck{
			{"question", isString},
			{"options", isStringSlice},
			{"correct_answer", isCorrectAnswer},
		}
	case models.QuestionAnswer:
		// Question-answer questions use multiple choice format like all other types
		label = "QuestionAnswer"
		checks = []fieldCheck{
			{"question", isString},
			{"options", isStringSlice},
			{"correct_answer", isCorrectAnswer},
		}
	default:
		span.SetAttributes(attribute.String("validation.result", "unknown_type"))
		return false, fmt.Sprintf("unknown question type: %v", qType)
	}

	for _, check := range checks {
		if !check.valid(content[check.name]) {
			span.SetAttributes(attribute.String("validation.result", "field_validation_failed"), attribute.String("field", check.name))
			return false, fmt.Sprintf("[%s] Validation failed for field '%s': %v", label, check.name, content[check.name])
		}
	}

	if qType == models.Vocabulary {
		// The field checks above already guarantee non-empty strings and
		// at least four options; only the exact count remains to verify.
		sentence, _ := content["sentence"].(string)
		targetWord, _ := content["question"].(string)
		options, _ := content["options"].([]interface{})
		if len(options) != 4 {
			span.SetAttributes(attribute.String("validation.result", "vocabulary_structure_failed"))
			return false, "[Vocabulary] Validation failed: missing or invalid sentence/question/options"
		}
		if !strings.Contains(sentence, targetWord) {
			span.SetAttributes(attribute.String("validation.result", "vocabulary_word_not_found"))
			return false, fmt.Sprintf("[Vocabulary] Validation failed: question '%s' not found in sentence '%s'", targetWord, sentence)
		}
	}

	span.SetAttributes(attribute.String("validation.result", "valid"))
	return true, ""
}
// GetConcurrencyStats returns a snapshot of the current concurrency metrics.
//
// Both mutexes are held for the duration of the snapshot. They are acquired in
// the order concurrencyMu -> statsMu, which is the same order releaseGlobalSlot
// uses; acquiring them the other way round can deadlock against a concurrent
// release (each goroutine holding the lock the other one is waiting for).
//
// UserActiveCount is a fresh map containing only users whose current request
// count is greater than zero, so callers may keep or modify it freely.
func (s *AIService) GetConcurrencyStats() ConcurrencyStats {
	s.concurrencyMu.RLock()
	defer s.concurrencyMu.RUnlock()
	s.statsMu.RLock()
	defer s.statsMu.RUnlock()

	// Copy only the users that currently hold at least one slot.
	userActiveCount := make(map[string]int, len(s.userRequestCount))
	for username, count := range s.userRequestCount {
		if count > 0 {
			userActiveCount[username] = count
		}
	}

	return ConcurrencyStats{
		ActiveRequests:  s.activeRequests,
		MaxConcurrent:   s.maxConcurrent,
		QueuedRequests:  0, // Currently we don't queue, we fail fast
		TotalRequests:   s.totalRequests,
		UserActiveCount: userActiveCount,
		MaxPerUser:      s.maxPerUser,
	}
}
// acquireGlobalSlot tries to claim one global concurrency slot without waiting.
//
// A slot is claimed by sending into s.globalSemaphore. Because the select has a
// default branch it never blocks: when no slot is free and ctx is not done, the
// call fails immediately with ErrServiceUnavailable. When ctx is already done it
// may instead return an ErrTimeout-wrapped ctx.Err(); if several cases are ready
// at once, Go picks one of them at random.
//
// Returns nil when a slot was claimed; the caller must later release it.
func (s *AIService) acquireGlobalSlot(ctx context.Context) error {
select {
case s.globalSemaphore <- struct{}{}:
// Slot claimed.
return nil
case <-ctx.Done():
return contextutils.WrapErrorf(contextutils.ErrTimeout, "request cancelled while waiting for global AI slot: %w", ctx.Err())
default:
// Semaphore is full: fail fast instead of queueing.
return contextutils.WrapErrorf(contextutils.ErrServiceUnavailable, "AI service at capacity (%d concurrent requests), please try again", s.maxConcurrent)
}
}
// releaseGlobalSlot returns one global concurrency slot and decrements the
// active request counter.
//
// Lock order: concurrencyMu is held for the whole call, and statsMu is taken
// inside it while the counter is updated. The receive from s.globalSemaphore is
// non-blocking; if the semaphore is empty nothing is released and a warning is
// logged instead.
func (s *AIService) releaseGlobalSlot(ctx context.Context) {
s.concurrencyMu.Lock()
defer s.concurrencyMu.Unlock()
select {
case <-s.globalSemaphore:
// Successfully released a slot
s.statsMu.Lock()
// Guard against driving the counter below zero.
if s.activeRequests > 0 {
s.activeRequests--
}
s.statsMu.Unlock()
default:
// No slot was acquired
s.logger.Warn(ctx, "WARNING: Attempted to release global AI slot but none were acquired", nil)
}
}
// acquireUserSlot reserves one per-user concurrency slot for username.
//
// It fails with ErrServiceUnavailable when the user already holds maxPerUser
// slots; otherwise the user's request count is incremented. The context
// argument is unused.
func (s *AIService) acquireUserSlot(_ context.Context, username string) error {
	s.concurrencyMu.Lock()
	defer s.concurrencyMu.Unlock()

	held := s.userRequestCount[username]
	if held < s.maxPerUser {
		s.userRequestCount[username] = held + 1
		return nil
	}
	return contextutils.WrapErrorf(contextutils.ErrServiceUnavailable, "user concurrency limit exceeded for %s: %d/%d", username, held, s.maxPerUser)
}
// releaseUserSlot gives back one per-user concurrency slot for username.
//
// If the user holds no slots the count is left untouched and a warning is
// logged, so the counter never goes negative.
func (s *AIService) releaseUserSlot(ctx context.Context, username string) {
	s.concurrencyMu.Lock()
	defer s.concurrencyMu.Unlock()

	held := s.userRequestCount[username]
	if held <= 0 {
		s.logger.Warn(ctx, "WARNING: Attempted to release user AI slot but none were acquired", map[string]interface{}{
			"username": username,
		})
		return
	}
	s.userRequestCount[username] = held - 1
}
// incrementTotalRequests adds one to the lifetime request counter while
// holding statsMu.
func (s *AIService) incrementTotalRequests() {
	s.statsMu.Lock()
	s.totalRequests += 1
	s.statsMu.Unlock()
}
// withConcurrencyControl runs operation while holding one global slot and one
// per-user slot for username.
//
// The call is rejected up front if the service is shutting down, if no global
// slot is available, or if the user is already at their per-user limit; in
// those cases operation is not invoked and the corresponding error is returned.
// Otherwise the result of operation is returned unchanged.
//
// Note that the total request counter is incremented before any slot is
// acquired, so requests rejected for capacity reasons are counted too.
func (s *AIService) withConcurrencyControl(ctx context.Context, username string, operation func() error) error {
// Check if service is shutting down
if s.isShutdown() {
return contextutils.WrapError(contextutils.ErrServiceUnavailable, "AI service is shutting down")
}
// Increment total request counter
s.incrementTotalRequests()
// Acquire global slot
if err := s.acquireGlobalSlot(ctx); err != nil {
return err
}
// Track active request
s.statsMu.Lock()
s.activeRequests++
s.statsMu.Unlock()
// releaseGlobalSlot both frees the semaphore slot and decrements
// activeRequests, undoing the two steps above.
defer func() {
s.releaseGlobalSlot(ctx)
}()
// Acquire per-user slot
if err := s.acquireUserSlot(ctx, username); err != nil {
return err
}
// Deferred calls run last-in-first-out: the user slot is released before
// the global slot.
defer s.releaseUserSlot(ctx, username)
// Execute the actual operation
return operation()
}
// supportsGrammarField reports whether the configured provider with the given
// code has SupportsGrammar set. Unknown providers, or a configuration without
// any providers, yield false.
func (s *AIService) supportsGrammarField(provider string) bool {
	providers := s.cfg.Providers
	if providers == nil {
		return false
	}
	for i := range providers {
		if providers[i].Code != provider {
			continue
		}
		// The first entry with a matching code decides.
		return providers[i].SupportsGrammar
	}
	return false
}
// supportsUsageInStreaming reports whether the configured provider with the
// given code has UsageSupported set. Providers that are not found in the
// configuration default to true.
func (s *AIService) supportsUsageInStreaming(provider string) bool {
	providers := s.cfg.Providers
	for i := range providers {
		if providers[i].Code != provider {
			continue
		}
		return providers[i].UsageSupported
	}
	// No matching provider entry: assume usage reporting is available.
	return true
}
// getQuestionBatchSize returns how many questions may be requested from the
// given provider in a single call. It is the provider's QuestionBatchSize when
// that value is positive, and 1 in every other case (no providers configured,
// provider not found, or a non-positive configured size).
func (s *AIService) getQuestionBatchSize(provider string) int {
	const defaultBatchSize = 1

	if s.cfg.Providers == nil {
		return defaultBatchSize
	}
	for _, candidate := range s.cfg.Providers {
		if candidate.Code != provider {
			continue
		}
		// Only the first matching entry is consulted.
		if candidate.QuestionBatchSize > 0 {
			return candidate.QuestionBatchSize
		}
		return defaultBatchSize
	}
	return defaultBatchSize
}
// GetQuestionBatchSize returns the maximum number of questions that can be
// generated in a single request for the given provider. It is the exported
// entry point for getQuestionBatchSize and returns exactly what that does.
func (s *AIService) GetQuestionBatchSize(provider string) int {
return s.getQuestionBatchSize(provider)
}
// VarietyService returns the variety service held by this AI service. The
// stored pointer is returned as-is, without copying.
func (s *AIService) VarietyService() *VarietyService {
return s.varietyService
}
// TemplateManager returns the template manager held by this AI service, giving
// callers access to prompt template rendering and example loading. The stored
// pointer is returned as-is, without copying.
func (s *AIService) TemplateManager() *AITemplateManager {
return s.templateManager
}
// SupportsGrammarField reports whether the provider supports the grammar
// field. It is the exported entry point for supportsGrammarField and returns
// exactly what that does.
func (s *AIService) SupportsGrammarField(provider string) bool {
return s.supportsGrammarField(provider)
}
// CallWithPrompt sends a raw prompt (and optional grammar) to the provider and
// returns the response. It forwards all arguments unchanged to callOpenAI and
// returns that call's result and error.
func (s *AIService) CallWithPrompt(ctx context.Context, userConfig *models.UserAIConfig, prompt, grammar string) (string, error) {
return s.callOpenAI(ctx, userConfig, prompt, grammar)
}
// trackAIUsage records the token usage of one AI request against a user.
//
// A userID of 0 means there is no user to attribute the usage to: the usage is
// logged at error level and nothing is recorded. Otherwise a single request
// with the given token counts is passed to the usage stats service. A failure
// to record is logged as a warning and otherwise ignored.
func (s *AIService) trackAIUsage(ctx context.Context, userConfig *models.UserAIConfig, usage Usage, userID int, apiKeyID *int) {
	if userID == 0 {
		skipped := map[string]interface{}{
			"provider":          userConfig.Provider,
			"model":             userConfig.Model,
			"prompt_tokens":     usage.PromptTokens,
			"completion_tokens": usage.CompletionTokens,
			"total_tokens":      usage.TotalTokens,
		}
		s.logger.Error(ctx, "Skipping AI usage tracking - no valid user ID in context", nil, skipped)
		return
	}

	// TODO: Determine usage type based on the context (this is a simple heuristic)
	usageType := "generic"
	// Each call accounts for exactly one request.
	const requestCount = 1

	recordErr := s.usageStatsSvc.RecordUserAITokenUsage(
		ctx,
		userID,
		apiKeyID,
		userConfig.Provider,
		userConfig.Model,
		usageType,
		usage.PromptTokens,
		usage.CompletionTokens,
		usage.TotalTokens,
		requestCount,
	)
	if recordErr == nil {
		return
	}
	s.logger.Warn(ctx, "Failed to record AI usage", map[string]interface{}{
		"error":   recordErr.Error(),
		"user_id": userID,
	})
}
// Package services provides embedded templates for AI service prompts
package services
import (
"embed"
"fmt"
"strings"
"text/template"
contextutils "quizapp/internal/utils"
)
// aiTemplatesFS holds the prompt template files matching templates/*.tmpl,
// embedded into the binary at build time.
//
//go:embed templates/*.tmpl
var aiTemplatesFS embed.FS
// exampleFilesFS holds the example JSON files matching
// templates/examples/*.json, embedded into the binary at build time.
//
//go:embed templates/examples/*.json
var exampleFilesFS embed.FS
// Template names as constants. Each value is a template file name ending in
// .tmpl; presumably these are the names passed to RenderTemplate (confirm
// against callers).
const (
// BatchQuestionPromptTemplate names the batch question prompt template.
BatchQuestionPromptTemplate = "batch_question_prompt.tmpl"
// ChatPromptTemplate names the chat prompt template.
ChatPromptTemplate = "chat_prompt.tmpl"
// JSONStructureGuidanceTemplate names the JSON structure guidance template.
JSONStructureGuidanceTemplate = "json_structure_guidance.tmpl"
// AIFixPromptTemplate names the AI fix prompt template.
AIFixPromptTemplate = "ai_fix_prompt.tmpl"
)
// AITemplateData holds data for rendering AI prompt templates.
//
// It is a single flat struct shared by all prompt templates; the fields are
// grouped below by the kind of prompt they are documented as serving. It is
// passed by value to RenderTemplate.
type AITemplateData struct {
// Common fields
Language string
Level string
QuestionType string
Topic string
RecentQuestionHistory []string
ReportReasons []string
Count int // For batch generation
// Variety fields for question generation
TopicCategory string
GrammarFocus string
VocabularyDomain string
Scenario string
StyleModifier string
DifficultyModifier string
TimeContext string
// Schema and formatting
SchemaForPrompt string // for direct inclusion in prompt for non-grammar providers
ExampleContent string // for including example in prompt
CurrentQuestionJSON string // the actual question JSON to pass into ai-fix prompt
AdditionalContext string // optional freeform context provided by admin when requesting AI fix
// Explanation specific
Question string
UserAnswer string
CorrectAnswer string // The text of the correct answer for explanations
// Chat specific
Passage string
Options []string
IsCorrect *bool // pointer, so nil can be told apart from false
ConversationHistory []ChatMessage
UserMessage string
// Priority-aware generation fields
UserWeakAreas []string
HighPriorityTopics []string
GapAnalysis map[string]int
FocusOnWeakAreas bool
FreshQuestionRatio float64
PriorityDistribution map[string]int
// Story generation fields
Title string
Subject string
AuthorStyle string
TimePeriod string
Genre string
Tone string
CharacterNames string
CustomInstructions string
TargetWords int
TargetSentences int
IsFirstSection bool
PreviousSections string
SectionText string
}
// ChatMessage represents a chat message for templates. It is the element type
// of AITemplateData.ConversationHistory.
type ChatMessage struct {
// Role identifies who produced the message.
Role string
// Content is the message text.
Content string
}
// AITemplateManager manages AI prompt templates. Create one with
// NewAITemplateManager; the zero value has a nil template set.
type AITemplateManager struct {
// templates is the parsed template set that RenderTemplate executes from.
templates *template.Template
}
// NewAITemplateManager builds a template manager by parsing every embedded
// file matching templates/*.tmpl. The parse error, if any, is returned
// unwrapped together with a nil manager.
func NewAITemplateManager() (result0 *AITemplateManager, err error) {
	parsed, parseErr := template.New("").ParseFS(aiTemplatesFS, "templates/*.tmpl")
	if parseErr != nil {
		return nil, parseErr
	}
	manager := &AITemplateManager{templates: parsed}
	return manager, nil
}
// RenderTemplate executes the named template with data and returns the
// rendered text. On failure it returns an empty string and the execution
// error unwrapped; any partially rendered output is discarded.
func (tm *AITemplateManager) RenderTemplate(templateName string, data AITemplateData) (result0 string, err error) {
	var rendered strings.Builder
	if execErr := tm.templates.ExecuteTemplate(&rendered, templateName, data); execErr != nil {
		return "", execErr
	}
	return rendered.String(), nil
}
// LoadExample returns the embedded example JSON for questionType, read from
// templates/examples/<questionType>_example.json. A read failure is wrapped
// as an internal error.
func (tm *AITemplateManager) LoadExample(questionType string) (result0 string, err error) {
	raw, readErr := exampleFilesFS.ReadFile(fmt.Sprintf("templates/examples/%s_example.json", questionType))
	if readErr == nil {
		return string(raw), nil
	}
	return "", contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to load example for %s: %w", questionType, readErr)
}
package services
import (
"context"
"crypto/rand"
"database/sql"
"encoding/hex"
"errors"
"time"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"golang.org/x/crypto/bcrypt"
)
// AuthAPIKeyServiceInterface defines the interface for auth API key operations.
// AuthAPIKeyService is the implementation in this file; the notes below
// describe that implementation's behavior.
type AuthAPIKeyServiceInterface interface {
// CreateAPIKey creates a key for the user and returns the stored record
// together with the raw key string.
CreateAPIKey(ctx context.Context, userID int, keyName, permissionLevel string) (*models.AuthAPIKey, string, error)
// ListAPIKeys returns the user's keys, newest first.
ListAPIKeys(ctx context.Context, userID int) ([]models.AuthAPIKey, error)
// GetAPIKeyByID returns the user's key with the given ID, or (nil, nil)
// when no such key exists.
GetAPIKeyByID(ctx context.Context, userID, keyID int) (*models.AuthAPIKey, error)
// DeleteAPIKey deletes the user's key with the given ID and returns an
// error when no such key exists.
DeleteAPIKey(ctx context.Context, userID, keyID int) error
// ValidateAPIKey returns the record matching rawKey, or an error when the
// key is malformed or unknown.
ValidateAPIKey(ctx context.Context, rawKey string) (*models.AuthAPIKey, error)
// UpdateLastUsed stamps the key's last_used_at with the current time.
UpdateLastUsed(ctx context.Context, keyID int) error
}
// AuthAPIKeyService implements AuthAPIKeyServiceInterface on top of the
// auth_api_keys database table.
type AuthAPIKeyService struct {
// db is the database handle used for all auth_api_keys queries.
db *sql.DB
// logger receives error and info messages from the service methods.
logger *observability.Logger
}
// NewAuthAPIKeyService creates a new AuthAPIKeyService instance that uses the
// given database handle and logger. Neither argument is validated here.
func NewAuthAPIKeyService(db *sql.DB, logger *observability.Logger) *AuthAPIKeyService {
return &AuthAPIKeyService{
db: db,
logger: logger,
}
}
const (
// KeyPrefix is the prefix for all auth API keys. ValidateAPIKey rejects any
// key that does not start with it.
KeyPrefix = "qapp_"
// KeyLength is the length of the random part of the key (32 characters).
// The random part is hex-encoded, so it is produced from KeyLength/2
// random bytes.
KeyLength = 32
)
// generateAPIKey returns a fresh API key: KeyPrefix followed by KeyLength
// hexadecimal characters drawn from a cryptographically secure source.
func generateAPIKey() (string, error) {
	// Every byte hex-encodes to two characters, so KeyLength/2 bytes
	// (16 bytes) yield the KeyLength-character (32) random part.
	entropy := make([]byte, KeyLength/2)
	_, err := rand.Read(entropy)
	if err != nil {
		return "", contextutils.WrapErrorf(err, "failed to generate random key: %w", err)
	}
	return KeyPrefix + hex.EncodeToString(entropy), nil
}
// hashAPIKey returns the bcrypt hash of key, computed at bcrypt.DefaultCost,
// as a string suitable for storage.
func hashAPIKey(key string) (string, error) {
	digest, err := bcrypt.GenerateFromPassword([]byte(key), bcrypt.DefaultCost)
	if err == nil {
		return string(digest), nil
	}
	return "", contextutils.WrapErrorf(err, "failed to hash API key: %w", err)
}
// CreateAPIKey creates and stores a new API key for userID.
//
// permissionLevel must pass models.IsValidPermissionLevel and keyName must be
// non-empty; otherwise an invalid-input error is returned. Only the bcrypt
// hash and a short display prefix of the key are stored. The raw key is
// returned as the second result, and this is the only time it is available.
func (s *AuthAPIKeyService) CreateAPIKey(ctx context.Context, userID int, keyName, permissionLevel string) (*models.AuthAPIKey, string, error) {
	ctx, span := observability.TraceFunction(ctx, "auth_api_key_service", "create_api_key")
	defer observability.FinishSpan(span, nil)
	span.SetAttributes(
		attribute.Int("user_id", userID),
		attribute.String("key_name", keyName),
		attribute.String("permission_level", permissionLevel),
	)

	// markFailed records err on the span with the given status description.
	markFailed := func(err error, status string) {
		span.RecordError(err)
		span.SetStatus(codes.Error, status)
	}

	// Input validation: permission level first, then key name.
	if !models.IsValidPermissionLevel(permissionLevel) {
		invalid := contextutils.NewAppError(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid permission level",
			"Permission level must be 'readonly' or 'full'",
		)
		markFailed(invalid, invalid.Error())
		return nil, "", invalid
	}
	if keyName == "" {
		missing := contextutils.NewAppError(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Key name is required",
			"",
		)
		markFailed(missing, missing.Error())
		return nil, "", missing
	}

	rawKey, err := generateAPIKey()
	if err != nil {
		markFailed(err, "failed to generate API key")
		return nil, "", contextutils.WrapError(err, "failed to generate API key")
	}
	keyHash, err := hashAPIKey(rawKey)
	if err != nil {
		markFailed(err, "failed to hash API key")
		return nil, "", contextutils.WrapError(err, "failed to hash API key")
	}

	// The stored prefix is the first 12 characters of the key (including
	// "qapp_"), or the whole key if it is shorter than that.
	keyPrefix := rawKey
	if len(keyPrefix) > 12 {
		keyPrefix = keyPrefix[:12]
	}

	const insertSQL = `
INSERT INTO auth_api_keys (user_id, key_name, key_hash, key_prefix, permission_level, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id, created_at, updated_at
`
	now := time.Now()
	apiKey := models.AuthAPIKey{
		UserID:          userID,
		KeyName:         keyName,
		KeyHash:         keyHash,
		KeyPrefix:       keyPrefix,
		PermissionLevel: permissionLevel,
	}
	row := s.db.QueryRowContext(ctx, insertSQL, userID, keyName, keyHash, keyPrefix, permissionLevel, now, now)
	if err = row.Scan(&apiKey.ID, &apiKey.CreatedAt, &apiKey.UpdatedAt); err != nil {
		s.logger.Error(ctx, "Failed to create API key", err, map[string]interface{}{
			"user_id":          userID,
			"key_name":         keyName,
			"permission_level": permissionLevel,
		})
		markFailed(err, "failed to insert API key")
		return nil, "", contextutils.WrapError(err, "failed to create API key")
	}

	span.SetAttributes(attribute.Int("api_key_id", apiKey.ID))
	s.logger.Info(ctx, "Created new API key", map[string]interface{}{
		"user_id":          userID,
		"api_key_id":       apiKey.ID,
		"key_name":         keyName,
		"permission_level": permissionLevel,
	})

	// Hand back the record plus the raw key (only time it's returned).
	return &apiKey, rawKey, nil
}
// ListAPIKeys returns every API key belonging to userID, ordered by creation
// time with the newest first. The result is a nil slice when the user has no
// keys. Query, scan and iteration failures are logged, recorded on the span
// and returned wrapped.
func (s *AuthAPIKeyService) ListAPIKeys(ctx context.Context, userID int) ([]models.AuthAPIKey, error) {
	ctx, span := observability.TraceFunction(ctx, "auth_api_key_service", "list_api_keys")
	defer observability.FinishSpan(span, nil)
	span.SetAttributes(attribute.Int("user_id", userID))

	// fail logs err, marks the span as failed and builds the wrapped error.
	fail := func(err error, logMsg, status, wrapMsg string) ([]models.AuthAPIKey, error) {
		s.logger.Error(ctx, logMsg, err, map[string]interface{}{"user_id": userID})
		span.RecordError(err)
		span.SetStatus(codes.Error, status)
		return nil, contextutils.WrapError(err, wrapMsg)
	}

	const listSQL = `
SELECT id, user_id, key_name, key_hash, key_prefix, permission_level, last_used_at, created_at, updated_at
FROM auth_api_keys
WHERE user_id = $1
ORDER BY created_at DESC
`
	rows, err := s.db.QueryContext(ctx, listSQL, userID)
	if err != nil {
		return fail(err, "Failed to list API keys", "failed to query API keys", "failed to list API keys")
	}
	defer func() { _ = rows.Close() }()

	var found []models.AuthAPIKey
	for rows.Next() {
		var entry models.AuthAPIKey
		scanErr := rows.Scan(
			&entry.ID,
			&entry.UserID,
			&entry.KeyName,
			&entry.KeyHash,
			&entry.KeyPrefix,
			&entry.PermissionLevel,
			&entry.LastUsedAt,
			&entry.CreatedAt,
			&entry.UpdatedAt,
		)
		if scanErr != nil {
			return fail(scanErr, "Failed to scan API key", "failed to scan API key", "failed to scan API key")
		}
		found = append(found, entry)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return fail(iterErr, "Error iterating API keys", "failed to iterate API keys", "failed to list API keys")
	}

	span.SetAttributes(attribute.Int("count", len(found)))
	return found, nil
}
// GetAPIKeyByID retrieves the API key with ID keyID, provided it belongs to
// userID.
//
// Returns (nil, nil) when no such key exists for that user; callers must check
// the returned pointer. Any other database failure is logged, recorded on the
// span and returned wrapped.
func (s *AuthAPIKeyService) GetAPIKeyByID(ctx context.Context, userID, keyID int) (*models.AuthAPIKey, error) {
	ctx, span := observability.TraceFunction(ctx, "auth_api_key_service", "get_api_key_by_id")
	defer observability.FinishSpan(span, nil)
	span.SetAttributes(
		attribute.Int("user_id", userID),
		attribute.Int("key_id", keyID),
	)

	query := `
SELECT id, user_id, key_name, key_hash, key_prefix, permission_level, last_used_at, created_at, updated_at
FROM auth_api_keys
WHERE id = $1 AND user_id = $2
`
	var apiKey models.AuthAPIKey
	err := s.db.QueryRowContext(ctx, query, keyID, userID).Scan(
		&apiKey.ID,
		&apiKey.UserID,
		&apiKey.KeyName,
		&apiKey.KeyHash,
		&apiKey.KeyPrefix,
		&apiKey.PermissionLevel,
		&apiKey.LastUsedAt,
		&apiKey.CreatedAt,
		&apiKey.UpdatedAt,
	)
	// errors.Is rather than == so a wrapped sql.ErrNoRows (for example from an
	// instrumented driver) is still treated as "not found".
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		s.logger.Error(ctx, "Failed to get API key", err, map[string]interface{}{
			"user_id": userID,
			"key_id":  keyID,
		})
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to get API key")
		return nil, contextutils.WrapError(err, "failed to get API key")
	}
	return &apiKey, nil
}
// DeleteAPIKey removes the API key with ID keyID if it belongs to userID.
//
// A record-not-found error is returned when the delete affected no rows, i.e.
// the key does not exist or belongs to a different user. Database failures
// are logged, recorded on the span and returned wrapped.
func (s *AuthAPIKeyService) DeleteAPIKey(ctx context.Context, userID, keyID int) error {
	ctx, span := observability.TraceFunction(ctx, "auth_api_key_service", "delete_api_key")
	defer observability.FinishSpan(span, nil)
	span.SetAttributes(
		attribute.Int("user_id", userID),
		attribute.Int("key_id", keyID),
	)

	// fail logs err, marks the span as failed and builds the wrapped error.
	fail := func(err error, logMsg, status, wrapMsg string) error {
		s.logger.Error(ctx, logMsg, err, map[string]interface{}{
			"user_id": userID,
			"key_id":  keyID,
		})
		span.RecordError(err)
		span.SetStatus(codes.Error, status)
		return contextutils.WrapError(err, wrapMsg)
	}

	res, err := s.db.ExecContext(ctx, `DELETE FROM auth_api_keys WHERE id = $1 AND user_id = $2`, keyID, userID)
	if err != nil {
		return fail(err, "Failed to delete API key", "failed to delete API key", "failed to delete API key")
	}
	deleted, err := res.RowsAffected()
	if err != nil {
		return fail(err, "Failed to get rows affected", "failed to get rows affected", "failed to check deletion")
	}

	if deleted == 0 {
		// Nothing matched both the key ID and the owning user.
		notFound := contextutils.NewAppError(
			contextutils.ErrorCodeRecordNotFound,
			contextutils.SeverityWarn,
			"API key not found",
			"",
		)
		span.RecordError(notFound)
		span.SetStatus(codes.Error, "API key not found")
		return notFound
	}

	s.logger.Info(ctx, "Deleted API key", map[string]interface{}{
		"user_id": userID,
		"key_id":  keyID,
	})
	return nil
}
// ValidateAPIKey validates a raw API key and returns the associated key info.
//
// The key must be non-empty and start with KeyPrefix. Keys are stored only as
// bcrypt hashes, so the raw key cannot be looked up directly; instead the rows
// whose key_prefix equals the leading characters of rawKey are fetched and each
// candidate hash is compared with bcrypt. Filtering by prefix keeps the number
// of (deliberately slow) bcrypt comparisons per request small instead of
// hashing against every key in the table.
//
// Returns the matching record, or an error when the key is empty, malformed,
// unknown, or the lookup fails. Rows that fail to scan are logged and skipped.
func (s *AuthAPIKeyService) ValidateAPIKey(ctx context.Context, rawKey string) (*models.AuthAPIKey, error) {
	ctx, span := observability.TraceFunction(ctx, "auth_api_key_service", "validate_api_key")
	defer observability.FinishSpan(span, nil)

	// Basic validation
	if rawKey == "" {
		return nil, errors.New("API key is empty")
	}
	if len(rawKey) < len(KeyPrefix) || rawKey[:len(KeyPrefix)] != KeyPrefix {
		span.SetStatus(codes.Error, "invalid API key format")
		return nil, errors.New("invalid API key format")
	}

	// Derive the lookup prefix exactly the way CreateAPIKey derives the stored
	// key_prefix: the first 12 characters of the raw key, or the whole key if
	// it is shorter.
	const storedPrefixLen = 12
	lookupPrefix := rawKey
	if len(lookupPrefix) > storedPrefixLen {
		lookupPrefix = lookupPrefix[:storedPrefixLen]
	}

	// Several keys may share a prefix, so every candidate is still verified
	// against its bcrypt hash below.
	query := `
SELECT id, user_id, key_name, key_hash, key_prefix, permission_level, last_used_at, created_at, updated_at
FROM auth_api_keys
WHERE key_prefix = $1
`
	rows, err := s.db.QueryContext(ctx, query, lookupPrefix)
	if err != nil {
		s.logger.Error(ctx, "Failed to query API keys for validation", err, nil)
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to query API keys")
		return nil, contextutils.WrapError(err, "failed to validate API key")
	}
	defer func() { _ = rows.Close() }()

	// Check each candidate by comparing its bcrypt hash
	for rows.Next() {
		var apiKey models.AuthAPIKey
		err := rows.Scan(
			&apiKey.ID,
			&apiKey.UserID,
			&apiKey.KeyName,
			&apiKey.KeyHash,
			&apiKey.KeyPrefix,
			&apiKey.PermissionLevel,
			&apiKey.LastUsedAt,
			&apiKey.CreatedAt,
			&apiKey.UpdatedAt,
		)
		if err != nil {
			// A single unreadable row must not block validation of the rest.
			s.logger.Error(ctx, "Failed to scan API key", err, nil)
			continue
		}
		if bcrypt.CompareHashAndPassword([]byte(apiKey.KeyHash), []byte(rawKey)) == nil {
			// Found matching key
			span.SetAttributes(
				attribute.Int("api_key_id", apiKey.ID),
				attribute.Int("user_id", apiKey.UserID),
				attribute.String("permission_level", apiKey.PermissionLevel),
			)
			return &apiKey, nil
		}
	}
	if err := rows.Err(); err != nil {
		s.logger.Error(ctx, "Error iterating API keys", err, nil)
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to iterate API keys")
		return nil, contextutils.WrapError(err, "failed to validate API key")
	}

	// No matching key found
	span.SetStatus(codes.Error, "invalid API key")
	return nil, errors.New("invalid API key")
}
// UpdateLastUsed sets last_used_at and updated_at of the given API key to the
// current time.
//
// The update is best-effort: a database failure is logged and recorded on the
// span, but the method always returns nil so that callers are never failed by
// this bookkeeping. This should be called asynchronously to avoid blocking
// requests.
func (s *AuthAPIKeyService) UpdateLastUsed(ctx context.Context, keyID int) error {
	ctx, span := observability.TraceFunction(ctx, "auth_api_key_service", "update_last_used")
	defer observability.FinishSpan(span, nil)
	span.SetAttributes(attribute.Int("key_id", keyID))

	stamp := time.Now()
	if _, execErr := s.db.ExecContext(ctx, `UPDATE auth_api_keys SET last_used_at = $1, updated_at = $2 WHERE id = $3`, stamp, stamp, keyID); execErr != nil {
		s.logger.Error(ctx, "Failed to update last used timestamp", execErr, map[string]interface{}{
			"key_id": keyID,
		})
		span.RecordError(execErr)
		span.SetStatus(codes.Error, "failed to update last used")
	}
	// Deliberately nil even after a failure: this is not critical.
	return nil
}
package services
import (
"context"
"database/sql"
"errors"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"quizapp/internal/observability"
)
// CleanupService handles database maintenance and cleanup tasks: removing
// questions with unsupported types and user responses whose question no
// longer exists.
type CleanupService struct {
// db is the database handle; the cleanup and stats methods that query it
// return an error when it is nil.
db *sql.DB
// logger receives progress and error messages from the cleanup methods.
logger *observability.Logger
}
// NewCleanupServiceWithLogger creates a new cleanup service that uses the
// given database handle and logger. Neither argument is validated here.
func NewCleanupServiceWithLogger(db *sql.DB, logger *observability.Logger) *CleanupService {
return &CleanupService{
db: db,
logger: logger,
}
}
// CleanupLegacyQuestionTypes deletes every question whose type is not one of
// 'vocabulary', 'fill_blank', 'qa' or 'reading_comprehension'.
//
// The matching rows are counted first; when there are none the method returns
// without issuing a delete. Any returned error is also recorded on the trace
// span by the deferred handler.
func (c *CleanupService) CleanupLegacyQuestionTypes(ctx context.Context) (err error) {
	ctx, span := observability.TraceCleanupFunction(ctx, "cleanup_legacy_question_types")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	if c.db == nil {
		return errors.New("database connection not available")
	}

	const countSQL = `
SELECT COUNT(*)
FROM questions
WHERE type NOT IN ('vocabulary', 'fill_blank', 'qa', 'reading_comprehension')
`
	const deleteSQL = `
DELETE FROM questions
WHERE type NOT IN ('vocabulary', 'fill_blank', 'qa', 'reading_comprehension')
`

	// Count first so the no-op case can be reported without deleting.
	var legacyTotal int
	if err = c.db.QueryRowContext(ctx, countSQL).Scan(&legacyTotal); err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}
	span.SetAttributes(attribute.Int("cleanup.legacy_questions_count", legacyTotal))
	if legacyTotal == 0 {
		c.logger.Info(ctx, "No legacy question types found to cleanup", map[string]interface{}{})
		span.SetAttributes(attribute.String("cleanup.result", "no_legacy_questions"))
		return nil
	}
	c.logger.Info(ctx, "Found questions with legacy types to cleanup", map[string]interface{}{"count": legacyTotal})

	outcome, err := c.db.ExecContext(ctx, deleteSQL)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}
	removed, err := outcome.RowsAffected()
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}

	span.SetAttributes(
		attribute.Int64("cleanup.rows_affected", removed),
		attribute.String("cleanup.result", "success"),
	)
	c.logger.Info(ctx, "Successfully cleaned up questions with legacy types", map[string]interface{}{"rows_affected": removed})
	return nil
}
// CleanupOrphanedResponses deletes user responses that refer to a question
// which no longer exists.
//
// The orphaned rows are counted first; when there are none the method returns
// without issuing a delete. Any returned error is also recorded on the trace
// span by the deferred handler.
func (c *CleanupService) CleanupOrphanedResponses(ctx context.Context) (err error) {
	ctx, span := observability.TraceCleanupFunction(ctx, "cleanup_orphaned_responses")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	if c.db == nil {
		return errors.New("database connection not available")
	}

	const countSQL = `
SELECT COUNT(*)
FROM user_responses ur
LEFT JOIN questions q ON ur.question_id = q.id
WHERE q.id IS NULL
`
	const deleteSQL = `
DELETE FROM user_responses
WHERE question_id NOT IN (SELECT id FROM questions)
`

	// Count first so the no-op case can be reported without deleting.
	var orphanTotal int
	if err = c.db.QueryRowContext(ctx, countSQL).Scan(&orphanTotal); err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}
	span.SetAttributes(attribute.Int("cleanup.orphaned_responses_count", orphanTotal))
	if orphanTotal == 0 {
		c.logger.Info(ctx, "No orphaned responses found to cleanup", map[string]interface{}{})
		span.SetAttributes(attribute.String("cleanup.result", "no_orphaned_responses"))
		return nil
	}
	c.logger.Info(ctx, "Found orphaned responses to cleanup", map[string]interface{}{"count": orphanTotal})

	outcome, err := c.db.ExecContext(ctx, deleteSQL)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}
	removed, err := outcome.RowsAffected()
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}

	span.SetAttributes(
		attribute.Int64("cleanup.rows_affected", removed),
		attribute.String("cleanup.result", "success"),
	)
	c.logger.Info(ctx, "Successfully cleaned up orphaned responses", map[string]interface{}{"rows_affected": removed})
	return nil
}
// RunFullCleanup runs every cleanup step in order: legacy question types
// first, then orphaned responses. It stops at the first step that fails,
// logs the failure and returns that step's error; later steps are not run.
func (c *CleanupService) RunFullCleanup(ctx context.Context) (err error) {
	ctx, span := observability.TraceCleanupFunction(ctx, "run_full_cleanup")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	span.SetAttributes(attribute.String("cleanup.start_time", time.Now().Format(time.RFC3339)))
	c.logger.Info(ctx, "Starting database cleanup", map[string]interface{}{"start_time": time.Now().Format(time.RFC3339)})

	// Steps run in slice order; each carries the message logged on failure.
	steps := []struct {
		failureMsg string
		run        func(context.Context) error
	}{
		{"Failed to cleanup legacy question types", c.CleanupLegacyQuestionTypes},
		{"Failed to cleanup orphaned responses", c.CleanupOrphanedResponses},
	}
	for _, step := range steps {
		if err = step.run(ctx); err != nil {
			c.logger.Error(ctx, step.failureMsg, err, map[string]interface{}{})
			span.SetAttributes(attribute.String("error", err.Error()))
			return err
		}
	}

	span.SetAttributes(
		attribute.String("cleanup.end_time", time.Now().Format(time.RFC3339)),
		attribute.String("cleanup.result", "success"),
	)
	c.logger.Info(ctx, "Database cleanup completed successfully", map[string]interface{}{"end_time": time.Now().Format(time.RFC3339)})
	return nil
}
// GetCleanupStats reports how many rows each cleanup step would currently
// remove, without deleting anything.
//
// The returned map has two keys: "legacy_questions" and "orphaned_responses".
// On any failure the map is nil and the error is returned (and recorded on the
// trace span by the deferred handler).
func (c *CleanupService) GetCleanupStats(ctx context.Context) (result0 map[string]int, err error) {
	ctx, span := observability.TraceCleanupFunction(ctx, "get_cleanup_stats")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	if c.db == nil {
		return nil, errors.New("database connection not available")
	}

	// countRows runs a single-value COUNT query and tags the span on failure.
	countRows := func(query string) (int, error) {
		var n int
		if scanErr := c.db.QueryRowContext(ctx, query).Scan(&n); scanErr != nil {
			span.SetAttributes(attribute.String("error", scanErr.Error()))
			return 0, scanErr
		}
		return n, nil
	}

	legacyCount, err := countRows(`
SELECT COUNT(*)
FROM questions
WHERE type NOT IN ('vocabulary', 'fill_blank', 'qa', 'reading_comprehension')
`)
	if err != nil {
		return nil, err
	}

	orphanedCount, err := countRows(`
SELECT COUNT(*)
FROM user_responses ur
LEFT JOIN questions q ON ur.question_id = q.id
WHERE q.id IS NULL
`)
	if err != nil {
		return nil, err
	}

	span.SetAttributes(
		attribute.Int("cleanup.stats.legacy_questions", legacyCount),
		attribute.Int("cleanup.stats.orphaned_responses", orphanedCount),
	)
	return map[string]int{
		"legacy_questions":   legacyCount,
		"orphaned_responses": orphanedCount,
	}, nil
}
package services
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"quizapp/internal/api"
contextutils "quizapp/internal/utils"
"github.com/google/uuid"
)
// ConversationServiceInterface defines the interface for AI conversation
// operations. Every method takes the acting user's ID alongside the target
// identifiers. The paginated list and search methods return a slice plus an
// int; presumably the int is the total match count (confirm against the
// implementation).
type ConversationServiceInterface interface {
// Conversation CRUD operations
CreateConversation(ctx context.Context, userID uint, req *api.CreateConversationRequest) (*api.Conversation, error)
GetConversation(ctx context.Context, conversationID string, userID uint) (*api.Conversation, error)
GetUserConversations(ctx context.Context, userID uint, limit, offset int) ([]api.Conversation, int, error)
UpdateConversation(ctx context.Context, conversationID string, userID uint, req *api.UpdateConversationRequest) (*api.Conversation, error)
DeleteConversation(ctx context.Context, conversationID string, userID uint) error
// Message operations
AddMessage(ctx context.Context, conversationID string, userID uint, req *api.CreateMessageRequest) (*api.ChatMessage, error)
GetConversationMessages(ctx context.Context, conversationID string, userID uint) ([]api.ChatMessage, error)
ToggleMessageBookmark(ctx context.Context, conversationID, messageID string, userID uint) (bool, error)
// Search operations
SearchMessages(ctx context.Context, userID uint, query string, limit, offset int) ([]api.ChatMessage, int, error)
SearchConversations(ctx context.Context, userID uint, query string, limit, offset int) ([]api.Conversation, int, error)
// Bookmark operations
GetBookmarkedMessages(ctx context.Context, userID uint, query string, limit, offset int) ([]api.ChatMessage, int, error)
// Utility operations
// GetUserMessageCounts returns a map of conversation ID -> message count for the user's conversations
GetUserMessageCounts(ctx context.Context, userID uint) (map[string]int, error)
}
// ConversationService handles all AI conversation-related operations.
// Its methods run their SQL directly against the wrapped database handle.
type ConversationService struct {
// db is the database handle used by every method of the service.
db *sql.DB
}
// NewConversationService builds a ConversationService that issues its
// queries through the supplied database handle.
func NewConversationService(db *sql.DB) *ConversationService {
	svc := new(ConversationService)
	svc.db = db
	return svc
}
// CreateConversation creates a new AI conversation owned by userID.
//
// A fresh UUID is generated for the row, and a single timestamp is captured
// and used for both created_at and updated_at so the two columns are
// identical on a newly created conversation. The inserted row is read back
// through the INSERT ... RETURNING clause and returned to the caller.
//
// It returns an error when req is nil or when the insert/scan fails.
func (s *ConversationService) CreateConversation(ctx context.Context, userID uint, req *api.CreateConversationRequest) (*api.Conversation, error) {
	if req == nil {
		return nil, contextutils.ErrorWithContextf("create conversation request is required")
	}

	conversationID := uuid.New()
	// Capture the time once: two separate time.Now() calls would store
	// slightly different created_at and updated_at values.
	now := time.Now()

	query := `
INSERT INTO ai_conversations (id, user_id, title, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, user_id, title, created_at, updated_at`

	var conversation api.Conversation
	err := s.db.QueryRowContext(ctx, query,
		conversationID,
		userID,
		req.Title,
		now,
		now,
	).Scan(
		&conversation.Id,
		&conversation.UserId,
		&conversation.Title,
		&conversation.CreatedAt,
		&conversation.UpdatedAt,
	)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to create conversation")
	}
	return &conversation, nil
}
// GetConversation loads a single conversation owned by userID and attaches
// its messages. The Messages field of the result always points at a non-nil
// slice, even when the conversation has no messages.
//
// A "conversation not found" error is returned when no row matches both the
// conversation ID and the user ID.
func (s *ConversationService) GetConversation(ctx context.Context, conversationID string, userID uint) (*api.Conversation, error) {
	const selectConversation = `
SELECT id, user_id, title, created_at, updated_at
FROM ai_conversations
WHERE id = $1 AND user_id = $2`

	conv := new(api.Conversation)
	row := s.db.QueryRowContext(ctx, selectConversation, conversationID, userID)
	scanErr := row.Scan(
		&conv.Id,
		&conv.UserId,
		&conv.Title,
		&conv.CreatedAt,
		&conv.UpdatedAt,
	)
	switch {
	case scanErr == sql.ErrNoRows:
		return nil, contextutils.ErrorWithContextf("conversation not found")
	case scanErr != nil:
		return nil, contextutils.WrapError(scanErr, "failed to get conversation")
	}

	// Load the message list through the regular accessor, which performs its
	// own ownership check.
	msgs, msgErr := s.GetConversationMessages(ctx, conversationID, userID)
	if msgErr != nil {
		return nil, contextutils.WrapError(msgErr, "failed to get conversation messages")
	}
	if msgs == nil {
		// Normalize "no messages" to an empty slice so the pointer below
		// never refers to a nil slice.
		msgs = make([]api.ChatMessage, 0)
	}
	conv.Messages = &msgs
	return conv, nil
}
// GetUserConversations returns one page of the user's conversations, most
// recently updated first, along with the total number of conversations the
// user owns (independent of limit and offset).
func (s *ConversationService) GetUserConversations(ctx context.Context, userID uint, limit, offset int) ([]api.Conversation, int, error) {
	// Total first, so the caller can paginate.
	var total int
	countRow := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM ai_conversations WHERE user_id = $1`, userID)
	if countErr := countRow.Scan(&total); countErr != nil {
		return nil, 0, contextutils.WrapError(countErr, "failed to count conversations")
	}

	const pageSQL = `
SELECT id, user_id, title, created_at, updated_at
FROM ai_conversations
WHERE user_id = $1
ORDER BY updated_at DESC
LIMIT $2 OFFSET $3`

	rows, queryErr := s.db.QueryContext(ctx, pageSQL, userID, limit, offset)
	if queryErr != nil {
		return nil, 0, contextutils.WrapError(queryErr, "failed to query conversations")
	}
	defer func() { _ = rows.Close() }()

	var page []api.Conversation
	for rows.Next() {
		var item api.Conversation
		if scanErr := rows.Scan(
			&item.Id,
			&item.UserId,
			&item.Title,
			&item.CreatedAt,
			&item.UpdatedAt,
		); scanErr != nil {
			return nil, 0, contextutils.WrapError(scanErr, "failed to scan conversation")
		}
		page = append(page, item)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, 0, contextutils.WrapError(iterErr, "error iterating conversations")
	}
	return page, total, nil
}
// GetUserMessageCounts maps each of the user's conversation IDs (as text) to
// the number of messages stored for it. Conversations without messages are
// included with a count of zero because the query uses a LEFT JOIN.
func (s *ConversationService) GetUserMessageCounts(ctx context.Context, userID uint) (map[string]int, error) {
	const countsSQL = `
SELECT c.id::text AS id, COUNT(m.id) AS message_count
FROM ai_conversations c
LEFT JOIN ai_chat_messages m ON m.conversation_id = c.id
WHERE c.user_id = $1
GROUP BY c.id`

	rows, queryErr := s.db.QueryContext(ctx, countsSQL, userID)
	if queryErr != nil {
		return nil, contextutils.WrapError(queryErr, "failed to query message counts")
	}
	defer func() { _ = rows.Close() }()

	result := map[string]int{}
	for rows.Next() {
		var (
			convID   string
			messages int
		)
		scanErr := rows.Scan(&convID, &messages)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan message count")
		}
		result[convID] = messages
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating message counts")
	}
	return result, nil
}
// UpdateConversation sets a new title on a conversation owned by userID,
// refreshes its updated_at timestamp and returns the updated row.
//
// A "conversation not found" error is returned when no row matches both the
// conversation ID and the user ID.
func (s *ConversationService) UpdateConversation(ctx context.Context, conversationID string, userID uint, req *api.UpdateConversationRequest) (*api.Conversation, error) {
	const updateSQL = `
UPDATE ai_conversations
SET title = $1, updated_at = $2
WHERE id = $3 AND user_id = $4
RETURNING id, user_id, title, created_at, updated_at`

	updated := new(api.Conversation)
	row := s.db.QueryRowContext(ctx, updateSQL, req.Title, time.Now(), conversationID, userID)
	scanErr := row.Scan(
		&updated.Id,
		&updated.UserId,
		&updated.Title,
		&updated.CreatedAt,
		&updated.UpdatedAt,
	)

	switch {
	case scanErr == nil:
		return updated, nil
	case scanErr == sql.ErrNoRows:
		return nil, contextutils.ErrorWithContextf("conversation not found")
	default:
		return nil, contextutils.WrapError(scanErr, "failed to update conversation")
	}
}
// DeleteConversation removes a conversation owned by userID. According to the
// schema note carried over from the original code, associated messages are
// removed by a CASCADE rule on the database side.
//
// The same "conversation not found" error is returned whether the
// conversation does not exist or belongs to a different user.
func (s *ConversationService) DeleteConversation(ctx context.Context, conversationID string, userID uint) error {
	// Ownership check before issuing the delete.
	var owner uint
	lookupErr := s.db.QueryRowContext(ctx, "SELECT user_id FROM ai_conversations WHERE id = $1", conversationID).Scan(&owner)
	switch {
	case lookupErr == sql.ErrNoRows:
		return contextutils.ErrorWithContextf("conversation not found")
	case lookupErr != nil:
		return contextutils.WrapError(lookupErr, "failed to verify conversation ownership")
	case owner != userID:
		return contextutils.ErrorWithContextf("conversation not found")
	}

	res, execErr := s.db.ExecContext(ctx, `DELETE FROM ai_conversations WHERE id = $1 AND user_id = $2`, conversationID, userID)
	if execErr != nil {
		return contextutils.WrapError(execErr, "failed to delete conversation")
	}

	affected, countErr := res.RowsAffected()
	if countErr != nil {
		return contextutils.WrapError(countErr, "failed to get rows affected")
	}
	if affected == 0 {
		// The row vanished between the ownership check and the delete.
		return contextutils.ErrorWithContextf("conversation not found")
	}
	return nil
}
// AddMessage adds a new message to a conversation owned by userID.
//
// The request content is JSON-encoded and stored in the answer_json column;
// the inserted row is read back through RETURNING and its content decoded
// again into the returned message. New messages are always stored with
// bookmarked = false, and created_at/updated_at share one timestamp.
//
// It returns an error when req is nil, when the conversation does not exist
// or belongs to another user ("conversation not found"), or when encoding,
// inserting or decoding fails.
func (s *ConversationService) AddMessage(ctx context.Context, conversationID string, userID uint, req *api.CreateMessageRequest) (*api.ChatMessage, error) {
	if req == nil {
		return nil, contextutils.ErrorWithContextf("create message request is required")
	}

	// First verify the conversation belongs to the user
	var ownerID uint
	err := s.db.QueryRowContext(ctx, "SELECT user_id FROM ai_conversations WHERE id = $1", conversationID).Scan(&ownerID)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, contextutils.ErrorWithContextf("conversation not found")
		}
		return nil, contextutils.WrapError(err, "failed to verify conversation ownership")
	}
	if ownerID != userID {
		return nil, contextutils.ErrorWithContextf("conversation not found")
	}

	// Encode the content as the JSON value to store.
	contentJSON, err := json.Marshal(req.Content)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to marshal message content")
	}

	messageID := uuid.New()
	// Capture the time once so created_at and updated_at are identical on a
	// freshly inserted row.
	now := time.Now()

	query := `
INSERT INTO ai_chat_messages (id, conversation_id, question_id, role, answer_json, bookmarked, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, conversation_id, question_id, role, answer_json, bookmarked, created_at, updated_at`

	var message api.ChatMessage
	var contentBytes []byte
	err = s.db.QueryRowContext(ctx, query,
		messageID,
		conversationID,
		req.QuestionId, // optional; a nil pointer is stored as NULL
		string(req.Role),
		contentJSON,
		false, // bookmarked defaults to false
		now,
		now,
	).Scan(
		&message.Id,
		&message.ConversationId,
		&message.QuestionId,
		&message.Role,
		&contentBytes,
		&message.Bookmarked,
		&message.CreatedAt,
		&message.UpdatedAt,
	)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to add message")
	}

	// Decode the stored content back into the message's content object.
	var contentObj struct {
		Text *string `json:"text,omitempty"`
	}
	if err = json.Unmarshal(contentBytes, &contentObj); err != nil {
		return nil, contextutils.WrapError(err, "failed to unmarshal message content")
	}
	message.Content = contentObj
	return &message, nil
}
// GetConversationMessages retrieves all messages of a conversation owned by
// userID, ordered by creation time (oldest first).
//
// The returned slice is nil when the conversation has no messages. A
// "conversation not found" error is returned when the conversation does not
// exist or belongs to another user.
func (s *ConversationService) GetConversationMessages(ctx context.Context, conversationID string, userID uint) ([]api.ChatMessage, error) {
	// First verify the conversation belongs to the user
	var ownerID uint
	err := s.db.QueryRowContext(ctx, "SELECT user_id FROM ai_conversations WHERE id = $1", conversationID).Scan(&ownerID)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, contextutils.ErrorWithContextf("conversation not found")
		}
		return nil, contextutils.WrapError(err, "failed to verify conversation ownership")
	}
	if ownerID != userID {
		return nil, contextutils.ErrorWithContextf("conversation not found")
	}

	query := `
SELECT id, conversation_id, question_id, role, answer_json, bookmarked, created_at, updated_at
FROM ai_chat_messages
WHERE conversation_id = $1
ORDER BY created_at ASC`
	rows, err := s.db.QueryContext(ctx, query, conversationID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query messages")
	}
	defer func() { _ = rows.Close() }()

	var messages []api.ChatMessage
	for rows.Next() {
		var msg api.ChatMessage
		var questionIDPtr *int
		var answerBytes []byte
		err := rows.Scan(
			&msg.Id,
			&msg.ConversationId,
			&questionIDPtr,
			&msg.Role,
			&answerBytes,
			&msg.Bookmarked,
			&msg.CreatedAt,
			&msg.UpdatedAt,
		)
		if err != nil {
			return nil, contextutils.WrapError(err, "failed to scan message")
		}

		// Content is stored as a JSON object; decode it into the content struct.
		var contentObj struct {
			Text *string `json:"text,omitempty"`
		}
		if err := json.Unmarshal(answerBytes, &contentObj); err != nil {
			return nil, contextutils.WrapError(err, "failed to unmarshal message content")
		}
		msg.Content = contentObj

		// A NULL question_id scans as a nil pointer, which is also the
		// field's zero value, so it can be assigned unconditionally.
		msg.QuestionId = questionIDPtr

		messages = append(messages, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "error iterating messages")
	}
	return messages, nil
}
// ToggleMessageBookmark toggles the bookmark status of a message
func (s *ConversationService) ToggleMessageBookmark(ctx context.Context, conversationID, messageID string, userID uint) (bool, error) {
// First verify the conversation belongs to the user
var ownerID uint
err := s.db.QueryRowContext(ctx, "SELECT user_id FROM ai_conversations WHERE id = $1", conversationID).Scan(&ownerID)
if err != nil {
if err == sql.ErrNoRows {
return false, contextutils.ErrorWithContextf("conversation not found")
}
return false, contextutils.WrapError(err, "failed to verify conversation ownership")
}
if ownerID != userID {
return false, contextutils.ErrorWithContextf("conversation not found")
}
// Get current bookmark status and toggle it
var currentBookmarked bool
err = s.db.QueryRowContext(ctx,
"SELECT bookmarked FROM ai_chat_messages WHERE id = $1 AND conversation_id = $2",
messageID, conversationID).Scan(¤tBookmarked)
if err != nil {
if err == sql.ErrNoRows {
return false, contextutils.ErrorWithContextf("message not found")
}
return false, contextutils.WrapError(err, "failed to get message bookmark status")
}
newBookmarked := !currentBookmarked
// Update the bookmark status
query := `UPDATE ai_chat_messages SET bookmarked = $1, updated_at = $2 WHERE id = $3 AND conversation_id = $4`
_, err = s.db.ExecContext(ctx, query, newBookmarked, time.Now(), messageID, conversationID)
if err != nil {
return false, contextutils.WrapError(err, "failed to update message bookmark status")
}
return newBookmarked, nil
}
// SearchMessages finds messages in the user's conversations whose stored
// content contains the query, ignoring case. It returns one limit/offset
// page of matches (newest first), each carrying its conversation title, and
// the total number of matches.
//
// The query is trimmed; an empty result of trimming is rejected with an
// error. Matching is done with LIKE against the lower-cased text form of the
// answer_json column.
func (s *ConversationService) SearchMessages(ctx context.Context, userID uint, query string, limit, offset int) ([]api.ChatMessage, int, error) {
	trimmed := strings.TrimSpace(query)
	if len(trimmed) == 0 {
		return nil, 0, contextutils.ErrorWithContextf("search query cannot be empty")
	}
	// Substring pattern: %<lower-cased query>%.
	likePattern := fmt.Sprintf("%%%s%%", strings.ToLower(trimmed))

	// Total number of matches, independent of paging.
	const countSQL = `
SELECT COUNT(*)
FROM ai_chat_messages m
JOIN ai_conversations c ON m.conversation_id = c.id
WHERE c.user_id = $1 AND LOWER(m.answer_json::text) LIKE $2`
	var total int
	if countErr := s.db.QueryRowContext(ctx, countSQL, userID, likePattern).Scan(&total); countErr != nil {
		return nil, 0, contextutils.WrapError(countErr, "failed to count search results")
	}

	// Requested page of matches, joined with the conversation title.
	const pageSQL = `
SELECT m.id, m.conversation_id, m.question_id, m.role, m.answer_json::text, m.bookmarked, m.created_at, m.updated_at, c.title
FROM ai_chat_messages m
JOIN ai_conversations c ON m.conversation_id = c.id
WHERE c.user_id = $1 AND LOWER(m.answer_json::text) LIKE $2
ORDER BY m.created_at DESC
LIMIT $3 OFFSET $4`
	rows, queryErr := s.db.QueryContext(ctx, pageSQL, userID, likePattern, limit, offset)
	if queryErr != nil {
		return nil, 0, contextutils.WrapError(queryErr, "failed to search messages")
	}
	defer func() { _ = rows.Close() }()

	var results []api.ChatMessage
	for rows.Next() {
		var (
			hit        api.ChatMessage
			questionID *int
			title      string
			rawContent []byte
		)
		if scanErr := rows.Scan(
			&hit.Id,
			&hit.ConversationId,
			&questionID,
			&hit.Role,
			&rawContent,
			&hit.Bookmarked,
			&hit.CreatedAt,
			&hit.UpdatedAt,
			&title,
		); scanErr != nil {
			return nil, 0, contextutils.WrapError(scanErr, "failed to scan search result")
		}

		// The column holds a JSON object; decode it into the content struct.
		var content struct {
			Text *string `json:"text,omitempty"`
		}
		if jsonErr := json.Unmarshal(rawContent, &content); jsonErr != nil {
			return nil, 0, contextutils.WrapError(jsonErr, "failed to unmarshal message content")
		}
		hit.Content = content
		hit.QuestionId = questionID // nil when question_id is NULL
		hit.ConversationTitle = &title

		results = append(results, hit)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, 0, contextutils.WrapError(iterErr, "error iterating search results")
	}
	return results, total, nil
}
// SearchConversations finds the user's conversations whose title or any
// message content contains the query, ignoring case. It returns one
// limit/offset page (most recently updated first) and the total number of
// matching conversations.
//
// Each returned conversation carries exactly one placeholder message whose
// text is a preview: the raw answer_json text of the latest message, or of
// the earliest one when no latest is available, or an empty string. The
// conversation's user ID is not selected and therefore left at its zero value.
func (s *ConversationService) SearchConversations(ctx context.Context, userID uint, query string, limit, offset int) ([]api.Conversation, int, error) {
	trimmed := strings.TrimSpace(query)
	if len(trimmed) == 0 {
		return nil, 0, contextutils.ErrorWithContextf("search query cannot be empty")
	}
	// Substring pattern applied to both titles and message content.
	likePattern := fmt.Sprintf("%%%s%%", strings.ToLower(trimmed))

	// Number of distinct conversations that match.
	const countSQL = `
SELECT COUNT(DISTINCT c.id)
FROM ai_conversations c
LEFT JOIN ai_chat_messages m ON c.id = m.conversation_id
WHERE c.user_id = $1
AND (LOWER(c.title) LIKE $2 OR LOWER(m.answer_json::text) LIKE $2)`
	var total int
	if countErr := s.db.QueryRowContext(ctx, countSQL, userID, likePattern).Scan(&total); countErr != nil {
		return nil, 0, contextutils.WrapError(countErr, "failed to count search results")
	}

	// Page of matching conversations plus message count and first/last message text.
	const pageSQL = `
SELECT DISTINCT c.id, c.title, c.created_at, c.updated_at,
(SELECT COUNT(*) FROM ai_chat_messages WHERE conversation_id = c.id) as message_count,
(SELECT answer_json::text FROM ai_chat_messages WHERE conversation_id = c.id ORDER BY created_at ASC LIMIT 1) as first_message,
(SELECT answer_json::text FROM ai_chat_messages WHERE conversation_id = c.id ORDER BY created_at DESC LIMIT 1) as last_message
FROM ai_conversations c
LEFT JOIN ai_chat_messages m ON c.id = m.conversation_id
WHERE c.user_id = $1
AND (LOWER(c.title) LIKE $2 OR LOWER(m.answer_json::text) LIKE $2)
ORDER BY c.updated_at DESC
LIMIT $3 OFFSET $4`
	rows, queryErr := s.db.QueryContext(ctx, pageSQL, userID, likePattern, limit, offset)
	if queryErr != nil {
		return nil, 0, contextutils.WrapError(queryErr, "failed to search conversations")
	}
	defer func() { _ = rows.Close() }()

	var results []api.Conversation
	for rows.Next() {
		var (
			item         api.Conversation
			first, last  *string
			messageCount int // selected by the query but not exposed on the result
		)
		if scanErr := rows.Scan(
			&item.Id,
			&item.Title,
			&item.CreatedAt,
			&item.UpdatedAt,
			&messageCount,
			&first,
			&last,
		); scanErr != nil {
			return nil, 0, contextutils.WrapError(scanErr, "failed to scan search result")
		}

		// Prefer the latest message as preview, fall back to the earliest.
		var preview string
		switch {
		case last != nil:
			preview = *last
		case first != nil:
			preview = *first
		}

		// Wrap the preview in a single placeholder message.
		var content struct {
			Text *string `json:"text,omitempty"`
		}
		content.Text = &preview
		placeholder := api.ChatMessage{Content: content}
		item.Messages = &[]api.ChatMessage{placeholder}

		results = append(results, item)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, 0, contextutils.WrapError(iterErr, "error iterating search results")
	}
	return results, total, nil
}
// GetBookmarkedMessages returns one limit/offset page of the user's
// bookmarked messages (newest first), each carrying its conversation title,
// together with the total number of bookmarked messages that match.
//
// When query is empty every bookmarked message matches. Otherwise the trimmed,
// lower-cased query is used as a LIKE substring filter on the text form of
// the answer_json column.
func (s *ConversationService) GetBookmarkedMessages(ctx context.Context, userID uint, query string, limit, offset int) ([]api.ChatMessage, int, error) {
	// "%" matches everything; narrow it only when a query was supplied.
	likePattern := "%"
	if query != "" {
		trimmed := strings.TrimSpace(query)
		likePattern = fmt.Sprintf("%%%s%%", strings.ToLower(trimmed))
	}

	// Total number of matching bookmarked messages.
	const countSQL = `
SELECT COUNT(*)
FROM ai_chat_messages m
JOIN ai_conversations c ON m.conversation_id = c.id
WHERE c.user_id = $1 AND m.bookmarked = true AND LOWER(m.answer_json::text) LIKE $2`
	var total int
	if countErr := s.db.QueryRowContext(ctx, countSQL, userID, likePattern).Scan(&total); countErr != nil {
		return nil, 0, contextutils.WrapError(countErr, "failed to count bookmarked messages")
	}

	// Requested page, joined with the conversation title.
	const pageSQL = `
SELECT m.id, m.conversation_id, m.question_id, m.role, m.answer_json::text, m.bookmarked, m.created_at, m.updated_at, c.title
FROM ai_chat_messages m
JOIN ai_conversations c ON m.conversation_id = c.id
WHERE c.user_id = $1 AND m.bookmarked = true AND LOWER(m.answer_json::text) LIKE $2
ORDER BY m.created_at DESC
LIMIT $3 OFFSET $4`
	rows, queryErr := s.db.QueryContext(ctx, pageSQL, userID, likePattern, limit, offset)
	if queryErr != nil {
		return nil, 0, contextutils.WrapError(queryErr, "failed to get bookmarked messages")
	}
	defer func() { _ = rows.Close() }()

	var bookmarks []api.ChatMessage
	for rows.Next() {
		var (
			entry      api.ChatMessage
			questionID *int
			title      string
			rawContent []byte
		)
		if scanErr := rows.Scan(
			&entry.Id,
			&entry.ConversationId,
			&questionID,
			&entry.Role,
			&rawContent,
			&entry.Bookmarked,
			&entry.CreatedAt,
			&entry.UpdatedAt,
			&title,
		); scanErr != nil {
			return nil, 0, contextutils.WrapError(scanErr, "failed to scan bookmarked message")
		}

		// The column holds a JSON object; decode it into the content struct.
		var content struct {
			Text *string `json:"text,omitempty"`
		}
		if jsonErr := json.Unmarshal(rawContent, &content); jsonErr != nil {
			return nil, 0, contextutils.WrapError(jsonErr, "failed to unmarshal message content")
		}
		entry.Content = content
		entry.QuestionId = questionID // nil when question_id is NULL
		entry.ConversationTitle = &title

		bookmarks = append(bookmarks, entry)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, 0, contextutils.WrapError(iterErr, "error iterating bookmarked messages")
	}
	return bookmarks, total, nil
}
package services
import (
"context"
"database/sql"
"fmt"
"time"
"quizapp/internal/api"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// DailyQuestionServiceInterface defines the interface for daily question operations.
// Every method is scoped to one user (userID) and most of them to one date.
type DailyQuestionServiceInterface interface {
// AssignDailyQuestions fills the user's assignments for date up to the
// user's daily goal; assignments that already exist are kept.
AssignDailyQuestions(ctx context.Context, userID int, date time.Time) error
// RegenerateDailyQuestions deletes the user's assignments for date and
// inserts a fresh set.
RegenerateDailyQuestions(ctx context.Context, userID int, date time.Time) error
GetDailyQuestions(ctx context.Context, userID int, date time.Time) ([]*models.DailyQuestionAssignmentWithQuestion, error)
MarkQuestionCompleted(ctx context.Context, userID, questionID int, date time.Time) error
ResetQuestionCompleted(ctx context.Context, userID, questionID int, date time.Time) error
SubmitDailyQuestionAnswer(ctx context.Context, userID, questionID int, date time.Time, userAnswerIndex int) (*api.AnswerResponse, error)
GetAvailableDates(ctx context.Context, userID int) ([]time.Time, error)
GetDailyProgress(ctx context.Context, userID int, date time.Time) (*models.DailyProgress, error)
// GetDailyQuestionsCount: AssignDailyQuestions treats the result as the
// number of questions already assigned to userID for date.
GetDailyQuestionsCount(ctx context.Context, userID int, date time.Time) (int, error)
GetCompletedDailyQuestionsCount(ctx context.Context, userID int, date time.Time) (int, error)
GetQuestionHistory(ctx context.Context, userID, questionID, days int) ([]*models.DailyQuestionHistory, error)
}
// DailyQuestionService implements daily question assignment and management
type DailyQuestionService struct {
// db is the handle used for assignment queries and transactions.
db *sql.DB
// logger receives the service's info, warning and error log entries.
logger *observability.Logger
// questionService supplies candidate questions (GetAdaptiveQuestionsForDaily)
// and the matching-question total used for diagnostics (GetAllQuestionsPaginated).
questionService QuestionServiceInterface
// learningService supplies the user's learning preferences, from which the
// daily goal is read.
learningService LearningServiceInterface
}
// NewDailyQuestionService wires a DailyQuestionService from its database
// handle, logger and the question/learning services it depends on.
func NewDailyQuestionService(db *sql.DB, logger *observability.Logger, questionService QuestionServiceInterface, learningService LearningServiceInterface) *DailyQuestionService {
	svc := new(DailyQuestionService)
	svc.db = db
	svc.logger = logger
	svc.questionService = questionService
	svc.learningService = learningService
	return svc
}
// AssignDailyQuestions tops up the user's daily question assignments for date.
//
// The target number of questions is the user's DailyGoal learning preference
// (10 when no positive goal is set). Assignments that already exist for the
// user/date are kept; only the missing slots are filled, using adaptive
// questions requested for the user's preferred language and current level.
//
// It returns nil without touching the database when the goal is already met,
// a *NoQuestionsAvailableError when the question service returns no
// candidates, and a wrapped error when a lookup or the insert transaction
// fails. Errors are also recorded on the tracing span.
func (s *DailyQuestionService) AssignDailyQuestions(ctx context.Context, userID int, date time.Time) (err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "AssignDailyQuestions",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Get user to determine language and level preferences
	user, err := s.getUserByID(ctx, userID)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get user")
	}
	if user == nil {
		return contextutils.ErrorWithContextf("user not found: %d", userID)
	}
	span.SetAttributes(attribute.String("user.name", user.Username))

	language := user.PreferredLanguage.String
	level := user.CurrentLevel.String
	if language == "" || level == "" {
		return contextutils.ErrorWithContextf("user missing language or level preferences")
	}

	// Get user's daily goal from learning preferences
	prefs, perr := s.learningService.GetUserLearningPreferences(ctx, userID)
	if perr != nil {
		span.RecordError(perr)
		return contextutils.WrapError(perr, "failed to get user learning preferences")
	}
	goal := 10
	if prefs != nil && prefs.DailyGoal > 0 {
		goal = prefs.DailyGoal
	}

	// Check existing assignments and only fill missing slots up to the user's goal
	existingCount, err := s.GetDailyQuestionsCount(ctx, userID, date)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to check existing assignments")
	}
	if existingCount >= goal {
		// s.logger.Info(ctx, "Daily questions already assigned for date", map[string]interface{}{
		// "user_id": userID,
		// "date": date.Format("2006-01-02"),
		// "count": existingCount,
		// "goal": goal,
		// })
		return nil // Already assigned
	}

	// Request more candidates than strictly needed to allow filtering out already-assigned questions
	buffer := 10 // request this many extra candidates beyond the user's goal
	reqLimit := goal + buffer

	// Get adaptive questions using an expanded limit so we can filter and still meet goal
	questionsWithStats, err := s.questionService.GetAdaptiveQuestionsForDaily(ctx, userID, language, level, reqLimit)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get adaptive questions for assignment")
	}
	if len(questionsWithStats) == 0 {
		// Gather diagnostics to explain why no questions were available
		var candidateIDs []int
		candidateCount := 0
		totalMatching := 0
		if s.questionService != nil {
			if candidates, qerr := s.questionService.GetAdaptiveQuestionsForDaily(ctx, userID, language, level, 50); qerr == nil && candidates != nil {
				candidateCount = len(candidates)
				for i, q := range candidates {
					if i >= 10 {
						break
					}
					if q != nil {
						candidateIDs = append(candidateIDs, q.ID)
					}
				}
			}
			if _, total, terr := s.questionService.GetAllQuestionsPaginated(ctx, 1, 1, "", "", "", language, level, nil); terr == nil {
				totalMatching = total
			}
		}
		return &NoQuestionsAvailableError{
			Language:       language,
			Level:          level,
			CandidateIDs:   candidateIDs,
			CandidateCount: candidateCount,
			TotalMatching:  totalMatching,
		}
	}

	// Filter out questions that are already assigned for this user/date to
	// avoid selecting already-inserted questions and thus underfilling the goal.
	// The lookup is best-effort: the conditional INSERT below still prevents
	// duplicate rows, so a failure here is logged rather than returned.
	assignedIDs := make(map[int]bool)
	rows, qerr := s.db.QueryContext(ctx, `SELECT question_id FROM daily_question_assignments WHERE user_id = $1 AND assignment_date = $2`, userID, date)
	if qerr != nil {
		s.logger.Warn(ctx, "Failed to load existing daily question assignments", map[string]interface{}{
			"user_id": userID,
			"date":    date.Format("2006-01-02"),
			"error":   qerr.Error(),
		})
	} else {
		defer func() {
			if closeErr := rows.Close(); closeErr != nil {
				s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
			}
		}()
		for rows.Next() {
			var qid int
			if scanErr := rows.Scan(&qid); scanErr != nil {
				s.logger.Warn(ctx, "Failed to scan existing daily question assignment", map[string]interface{}{
					"user_id": userID,
					"date":    date.Format("2006-01-02"),
					"error":   scanErr.Error(),
				})
				continue
			}
			assignedIDs[qid] = true
		}
		if iterErr := rows.Err(); iterErr != nil {
			s.logger.Warn(ctx, "Failed to iterate existing daily question assignments", map[string]interface{}{
				"user_id": userID,
				"date":    date.Format("2006-01-02"),
				"error":   iterErr.Error(),
			})
		}
	}

	// Convert QuestionWithStats to Question for assignment, skipping already-assigned
	var questions []models.Question
	for _, qws := range questionsWithStats {
		if qws == nil || qws.Question == nil {
			continue
		}
		if assignedIDs[qws.ID] {
			// already assigned for this date, skip
			continue
		}
		questions = append(questions, *qws.Question)
	}

	// Only insert up to the number of slots we need to fill
	toAssign := goal - existingCount
	if toAssign < 0 {
		toAssign = 0
	}
	if len(questions) > toAssign {
		questions = questions[:toAssign]
	}

	// Begin transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	// Roll back whenever the function exits with a non-nil named error.
	defer func() {
		if err != nil {
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				s.logger.Error(ctx, "Failed to rollback transaction", rollbackErr, map[string]interface{}{
					"user_id": userID,
					"date":    date.Format("2006-01-02"),
				})
			}
		}
	}()

	// Insert assignments (idempotent via conditional INSERT to avoid duplicate rows)
	insertQuery := `
INSERT INTO daily_question_assignments (user_id, question_id, assignment_date, created_at)
SELECT $1, $2, $3, $4
WHERE NOT EXISTS (
SELECT 1 FROM daily_question_assignments WHERE user_id = $1 AND question_id = $2 AND assignment_date = $3
)
`
	for _, question := range questions {
		_, err = tx.ExecContext(ctx, insertQuery, userID, question.ID, date, time.Now())
		if err != nil {
			span.RecordError(err)
			return contextutils.WrapError(err, "failed to insert assignment")
		}
	}

	// Commit transaction
	err = tx.Commit()
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to commit transaction")
	}

	s.logger.Info(ctx, "Daily questions assigned successfully", map[string]interface{}{
		"user_id": userID,
		"date":    date.Format("2006-01-02"),
		"count":   len(questions),
	})
	return nil
}
// RegenerateDailyQuestions clears existing daily question assignments and
// creates new ones for a user and date.
//
// The number of questions inserted is capped at the user's DailyGoal learning
// preference (10 when no positive goal is set). Candidates come from the
// question service's adaptive selection for the user's preferred language and
// current level. The delete and the inserts run in one transaction, which is
// rolled back when the function exits with an error.
//
// It returns a *NoQuestionsAvailableError when the question service returns
// no usable candidates (in which case existing assignments are left
// untouched), and a wrapped error when a lookup or the transaction fails.
func (s *DailyQuestionService) RegenerateDailyQuestions(ctx context.Context, userID int, date time.Time) (err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "RegenerateDailyQuestions",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Get user to determine language and level preferences
	user, err := s.getUserByID(ctx, userID)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get user")
	}
	if user == nil {
		return contextutils.ErrorWithContextf("user not found: %d", userID)
	}

	language := user.PreferredLanguage.String
	level := user.CurrentLevel.String
	if language == "" || level == "" {
		return contextutils.ErrorWithContextf("user missing language or level preferences")
	}

	// Get user's daily goal from learning preferences
	prefs, perr := s.learningService.GetUserLearningPreferences(ctx, userID)
	if perr != nil {
		span.RecordError(perr)
		return contextutils.WrapError(perr, "failed to get user learning preferences")
	}
	goal := 10
	if prefs != nil && prefs.DailyGoal > 0 {
		goal = prefs.DailyGoal
	}

	// Request more candidates than strictly needed so that skipped (nil)
	// candidates do not leave the goal underfilled.
	buffer := 10 // request this many extra candidates beyond the user's goal
	reqLimit := goal + buffer

	// Get adaptive questions using an expanded limit so we can filter and still meet goal
	questionsWithStats, err := s.questionService.GetAdaptiveQuestionsForDaily(ctx, userID, language, level, reqLimit)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get adaptive questions for assignment")
	}
	if len(questionsWithStats) == 0 {
		// Gather diagnostics to explain why no questions were available
		var candidateIDs []int
		candidateCount := 0
		totalMatching := 0
		if s.questionService != nil {
			if candidates, qerr := s.questionService.GetAdaptiveQuestionsForDaily(ctx, userID, language, level, 50); qerr == nil && candidates != nil {
				candidateCount = len(candidates)
				for i, q := range candidates {
					if i >= 10 {
						break
					}
					if q != nil {
						candidateIDs = append(candidateIDs, q.ID)
					}
				}
			}
			if _, total, terr := s.questionService.GetAllQuestionsPaginated(ctx, 1, 1, "", "", "", language, level, nil); terr == nil {
				totalMatching = total
			}
		}
		return &NoQuestionsAvailableError{
			Language:       language,
			Level:          level,
			CandidateIDs:   candidateIDs,
			CandidateCount: candidateCount,
			TotalMatching:  totalMatching,
		}
	}

	// Convert QuestionWithStats to Question for assignment. Nil entries are
	// skipped (as AssignDailyQuestions does) instead of being dereferenced.
	var questions []models.Question
	for _, qws := range questionsWithStats {
		if qws == nil || qws.Question == nil {
			continue
		}
		questions = append(questions, *qws.Question)
	}
	if len(questions) == 0 {
		// Nothing usable to insert: bail out before deleting the existing
		// assignments so the user's day is not wiped without replacement.
		return &NoQuestionsAvailableError{
			Language:       language,
			Level:          level,
			CandidateCount: len(questionsWithStats),
		}
	}

	// Begin transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	// Roll back whenever the function exits with a non-nil named error.
	defer func() {
		if err != nil {
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				s.logger.Error(ctx, "Failed to rollback transaction", rollbackErr, map[string]interface{}{
					"user_id": userID,
					"date":    date.Format("2006-01-02"),
				})
			}
		}
	}()

	// First, delete existing assignments for this user and date
	deleteQuery := `DELETE FROM daily_question_assignments WHERE user_id = $1 AND assignment_date = $2`
	_, err = tx.ExecContext(ctx, deleteQuery, userID, date)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to delete existing assignments")
	}

	// Insert new assignments
	insertQuery := `
INSERT INTO daily_question_assignments (user_id, question_id, assignment_date, created_at)
VALUES ($1, $2, $3, $4)
`
	stmt, err := tx.PrepareContext(ctx, insertQuery)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to prepare statement")
	}
	defer func() {
		if closeErr := stmt.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close statement", closeErr, map[string]interface{}{
				"user_id": userID,
				"date":    date.Format("2006-01-02"),
			})
		}
	}()

	// Only assign up to the goal amount
	assignedCount := 0
	for _, question := range questions {
		if assignedCount >= goal {
			break
		}
		_, err = stmt.ExecContext(ctx, userID, question.ID, date, time.Now())
		if err != nil {
			span.RecordError(err)
			return contextutils.WrapError(err, "failed to insert assignment")
		}
		assignedCount++
	}

	// Commit transaction
	err = tx.Commit()
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to commit transaction")
	}

	// Report the number of rows actually inserted, not the (larger)
	// candidate list that was fetched with the extra buffer.
	s.logger.Info(ctx, "Daily questions regenerated successfully", map[string]interface{}{
		"user_id": userID,
		"date":    date.Format("2006-01-02"),
		"count":   assignedCount,
	})
	return nil
}
// GetDailyQuestions returns the daily question assignments a user has on the
// given date, each joined with its question and per-user statistics (how often
// the question was assigned in Daily, plus total/correct/incorrect response
// counts). Results are ordered by assignment creation time.
//
// Rows that cannot be scanned, or whose question content fails to unmarshal,
// are logged and skipped instead of failing the call. Query and iteration
// errors are wrapped and returned together with a nil slice.
func (s *DailyQuestionService) GetDailyQuestions(ctx context.Context, userID int, date time.Time) (result0 []*models.DailyQuestionAssignmentWithQuestion, err error) {
	day := date.Format("2006-01-02")
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetDailyQuestions",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", day),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	// logFields builds a fresh field map per log call so entries never share state.
	logFields := func() map[string]interface{} {
		return map[string]interface{}{
			"user_id": userID,
			"date":    day,
		}
	}

	const selectAssignments = `
SELECT dqa.id, dqa.user_id, dqa.question_id, dqa.assignment_date,
dqa.is_completed, dqa.completed_at, dqa.created_at,
dqa.user_answer_index, dqa.submitted_at,
q.id, q.type, q.language, q.level, q.difficulty_score, q.content,
q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario,
q.style_modifier, q.difficulty_modifier, q.time_context,
-- Daily shown count per user: how many times this user has seen this question in Daily across all dates
(SELECT COUNT(*) FROM daily_question_assignments dqa_all WHERE dqa_all.question_id = dqa.question_id AND dqa_all.user_id = dqa.user_id) AS daily_shown_count,
-- Per-user correctness stats across all time
COALESCE((SELECT COUNT(*) FROM user_responses ur WHERE ur.user_id = dqa.user_id AND ur.question_id = dqa.question_id), 0) AS user_total_responses,
COALESCE((SELECT COUNT(*) FROM user_responses ur WHERE ur.user_id = dqa.user_id AND ur.question_id = dqa.question_id AND ur.is_correct = TRUE), 0) AS user_correct_count,
COALESCE((SELECT COUNT(*) FROM user_responses ur WHERE ur.user_id = dqa.user_id AND ur.question_id = dqa.question_id AND ur.is_correct = FALSE), 0) AS user_incorrect_count
FROM daily_question_assignments dqa
JOIN questions q ON dqa.question_id = q.id
WHERE dqa.user_id = $1 AND dqa.assignment_date = $2
ORDER BY dqa.created_at ASC
`
	rows, queryErr := s.db.QueryContext(ctx, selectAssignments, userID, date)
	if queryErr != nil {
		span.RecordError(queryErr)
		return nil, contextutils.WrapError(queryErr, "failed to query daily questions")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close rows", closeErr, logFields())
		}
	}()

	var collected []*models.DailyQuestionAssignmentWithQuestion
	for rows.Next() {
		// Fresh values per iteration: their addresses are retained in the result.
		var (
			row     models.DailyQuestionAssignmentWithQuestion
			q       models.Question
			rawJSON string
		)
		scanErr := rows.Scan(
			&row.ID, &row.UserID, &row.QuestionID, &row.AssignmentDate,
			&row.IsCompleted, &row.CompletedAt, &row.CreatedAt,
			&row.UserAnswerIndex, &row.SubmittedAt,
			&q.ID, &q.Type, &q.Language, &q.Level, &q.DifficultyScore,
			&rawJSON, &q.CorrectAnswer, &q.Explanation, &q.CreatedAt, &q.Status,
			&q.TopicCategory, &q.GrammarFocus, &q.VocabularyDomain, &q.Scenario,
			&q.StyleModifier, &q.DifficultyModifier, &q.TimeContext,
			&row.DailyShownCount,
			&row.UserTotalResponses,
			&row.UserCorrectCount,
			&row.UserIncorrectCount,
		)
		if scanErr != nil {
			// Best effort: one bad row must not hide the remaining assignments.
			s.logger.Error(ctx, "Failed to scan daily question assignment", scanErr, logFields())
			continue
		}

		// The content column holds JSON that has to be decoded into the question.
		if decodeErr := q.UnmarshalContentFromJSON(rawJSON); decodeErr != nil {
			fields := logFields()
			fields["content"] = rawJSON
			s.logger.Error(ctx, "Failed to unmarshal question content", decodeErr, fields)
			continue
		}

		row.Question = &q
		collected = append(collected, &row)
	}

	if iterErr := rows.Err(); iterErr != nil {
		span.RecordError(iterErr)
		return nil, contextutils.WrapError(iterErr, "error iterating over rows")
	}
	return collected, nil
}
// MarkQuestionCompleted flags the assignment identified by (userID, questionID,
// date) as completed and stamps completed_at with the current time.
//
// It returns contextutils.ErrAssignmentNotFound when no row matched, and a
// wrapped error when the update or the affected-row lookup fails.
func (s *DailyQuestionService) MarkQuestionCompleted(ctx context.Context, userID, questionID int, date time.Time) (err error) {
	day := date.Format("2006-01-02")
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "MarkQuestionCompleted",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.Int("question.id", questionID),
			attribute.String("date", day),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const markCompleted = `
UPDATE daily_question_assignments
SET is_completed = true, completed_at = $1
WHERE user_id = $2 AND question_id = $3 AND assignment_date = $4
`
	res, execErr := s.db.ExecContext(ctx, markCompleted, time.Now(), userID, questionID, date)
	if execErr != nil {
		span.RecordError(execErr)
		return contextutils.WrapError(execErr, "failed to mark question as completed")
	}

	affected, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		span.RecordError(countErr)
		return contextutils.WrapError(countErr, "failed to get rows affected")
	case affected == 0:
		// Nothing matched: the user has no such assignment on that date.
		return contextutils.ErrAssignmentNotFound
	}

	s.logger.Info(ctx, "Question marked as completed", map[string]interface{}{
		"user_id":     userID,
		"question_id": questionID,
		"date":        day,
	})
	return nil
}
// ResetQuestionCompleted returns the assignment identified by (userID,
// questionID, date) to its unanswered state: is_completed becomes false and
// completed_at, user_answer_index and submitted_at are cleared.
//
// It returns contextutils.ErrAssignmentNotFound when no row matched, and a
// wrapped error when the update or the affected-row lookup fails.
func (s *DailyQuestionService) ResetQuestionCompleted(ctx context.Context, userID, questionID int, date time.Time) (err error) {
	day := date.Format("2006-01-02")
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "ResetQuestionCompleted",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.Int("question.id", questionID),
			attribute.String("date", day),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const resetCompletion = `
UPDATE daily_question_assignments
SET is_completed = false, completed_at = NULL, user_answer_index = NULL, submitted_at = NULL
WHERE user_id = $1 AND question_id = $2 AND assignment_date = $3
`
	res, execErr := s.db.ExecContext(ctx, resetCompletion, userID, questionID, date)
	if execErr != nil {
		span.RecordError(execErr)
		return contextutils.WrapError(execErr, "failed to reset question completion")
	}

	affected, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		span.RecordError(countErr)
		return contextutils.WrapError(countErr, "failed to get rows affected")
	case affected == 0:
		// Nothing matched: the user has no such assignment on that date.
		return contextutils.ErrAssignmentNotFound
	}

	s.logger.Info(ctx, "Question reset to not completed", map[string]interface{}{
		"user_id":     userID,
		"question_id": questionID,
		"date":        day,
	})
	return nil
}
// GetAvailableDates lists the distinct dates on which the user has at least one
// daily question assignment, newest first.
//
// A row that fails to scan is logged and skipped. Query and iteration errors
// are wrapped and returned together with a nil slice.
func (s *DailyQuestionService) GetAvailableDates(ctx context.Context, userID int) (result0 []time.Time, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetAvailableDates",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const selectDates = `
SELECT DISTINCT assignment_date
FROM daily_question_assignments
WHERE user_id = $1
ORDER BY assignment_date DESC
`
	rows, queryErr := s.db.QueryContext(ctx, selectDates, userID)
	if queryErr != nil {
		span.RecordError(queryErr)
		return nil, contextutils.WrapError(queryErr, "failed to query available dates")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{
			"user_id": userID,
		})
	}()

	var found []time.Time
	for rows.Next() {
		var day time.Time
		if scanErr := rows.Scan(&day); scanErr != nil {
			// Best effort: skip the unreadable row and keep the rest.
			s.logger.Error(ctx, "Failed to scan date", scanErr, map[string]interface{}{
				"user_id": userID,
			})
			continue
		}
		found = append(found, day)
	}

	if iterErr := rows.Err(); iterErr != nil {
		span.RecordError(iterErr)
		return nil, contextutils.WrapError(iterErr, "error iterating over rows")
	}
	return found, nil
}
// GetDailyProgress reports how many daily questions the user has been assigned
// on the given date and how many of those are completed.
//
// The returned DailyProgress carries the requested date unchanged. A failing
// query is wrapped and returned with a nil result.
func (s *DailyQuestionService) GetDailyProgress(ctx context.Context, userID int, date time.Time) (result0 *models.DailyProgress, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetDailyProgress",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const selectProgress = `
SELECT
COUNT(*) as total,
COUNT(CASE WHEN is_completed = true THEN 1 END) as completed
FROM daily_question_assignments
WHERE user_id = $1 AND assignment_date = $2
`
	var (
		assigned int
		done     int
	)
	row := s.db.QueryRowContext(ctx, selectProgress, userID, date)
	if scanErr := row.Scan(&assigned, &done); scanErr != nil {
		return nil, contextutils.WrapError(scanErr, "failed to get daily progress")
	}

	return &models.DailyProgress{
		Date:      date,
		Completed: done,
		Total:     assigned,
	}, nil
}
// GetDailyQuestionsCount returns how many daily questions are assigned to the
// user on the given date, completed or not.
//
// A failing query is wrapped and returned with a count of 0.
func (s *DailyQuestionService) GetDailyQuestionsCount(ctx context.Context, userID int, date time.Time) (result0 int, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetDailyQuestionsCount",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const countAssignments = `
SELECT COUNT(*)
FROM daily_question_assignments
WHERE user_id = $1 AND assignment_date = $2
`
	var assigned int
	row := s.db.QueryRowContext(ctx, countAssignments, userID, date)
	if scanErr := row.Scan(&assigned); scanErr != nil {
		return 0, contextutils.WrapError(scanErr, "failed to get daily questions count")
	}
	return assigned, nil
}
// GetCompletedDailyQuestionsCount returns how many of the user's daily
// questions on the given date are marked as completed.
//
// A failing query is wrapped and returned with a count of 0.
func (s *DailyQuestionService) GetCompletedDailyQuestionsCount(ctx context.Context, userID int, date time.Time) (result0 int, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetCompletedDailyQuestionsCount",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const countCompleted = `
SELECT COUNT(*)
FROM daily_question_assignments
WHERE user_id = $1 AND assignment_date = $2 AND is_completed = true
`
	var done int
	row := s.db.QueryRowContext(ctx, countCompleted, userID, date)
	if scanErr := row.Scan(&done); scanErr != nil {
		return 0, contextutils.WrapError(scanErr, "failed to get completed daily questions count")
	}
	return done, nil
}
// GetQuestionHistory returns, oldest first, the user's daily assignments of one
// question whose assignment date lies between now minus `days` days and
// tomorrow. Each entry carries the completion flag, the submission time and,
// when a linked user response exists, whether that response was correct
// (IsCorrect stays nil otherwise).
//
// days must be positive. Rows that fail to scan are logged and skipped; query
// and iteration errors are wrapped and returned with a nil slice.
func (s *DailyQuestionService) GetQuestionHistory(ctx context.Context, userID, questionID, days int) (result0 []*models.DailyQuestionHistory, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetQuestionHistory",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.Int("question.id", questionID),
			attribute.Int("days", days),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	if days <= 0 {
		return nil, contextutils.ErrorWithContextf("days must be positive")
	}

	// The look-back window is spliced into the SQL text as an interval literal;
	// it is formatted from a validated int, so no caller-supplied text reaches
	// the statement.
	lookback := fmt.Sprintf("%d days", days)
	historyQuery := `
SELECT dqa.assignment_date, dqa.is_completed, dqa.submitted_at,
ur.is_correct
FROM daily_question_assignments dqa
LEFT JOIN daily_assignment_responses dar ON dar.assignment_id = dqa.id
LEFT JOIN user_responses ur ON ur.id = dar.user_response_id
WHERE dqa.user_id = $1 AND dqa.question_id = $2
AND dqa.assignment_date >= NOW() - INTERVAL '` + lookback + `'
AND dqa.assignment_date <= CURRENT_DATE + INTERVAL '1 day'
ORDER BY dqa.assignment_date ASC
`
	rows, queryErr := s.db.QueryContext(ctx, historyQuery, userID, questionID)
	if queryErr != nil {
		span.RecordError(queryErr)
		return nil, contextutils.WrapError(queryErr, "failed to query question history")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
			"days":        days,
		})
	}()

	var entries []*models.DailyQuestionHistory
	for rows.Next() {
		// Fresh values per iteration: the entry's address is kept in the result.
		var (
			entry   models.DailyQuestionHistory
			correct sql.NullBool // NULL when the LEFT JOINs found no response
		)
		scanErr := rows.Scan(
			&entry.AssignmentDate,
			&entry.IsCompleted,
			&entry.SubmittedAt,
			&correct,
		)
		if scanErr != nil {
			s.logger.Error(ctx, "Failed to scan question history entry", scanErr, map[string]interface{}{
				"user_id":         userID,
				"question_id":     questionID,
				"assignment_date": entry.AssignmentDate,
			})
			continue
		}

		// Only expose correctness when a response row was actually joined.
		if correct.Valid {
			answeredCorrectly := correct.Bool
			entry.IsCorrect = &answeredCorrectly
		}
		entries = append(entries, &entry)
	}

	if iterErr := rows.Err(); iterErr != nil {
		span.RecordError(iterErr)
		return nil, contextutils.WrapError(iterErr, "error iterating over rows")
	}
	return entries, nil
}
// getUserByID loads a single user row by primary key.
//
// It returns (nil, nil) when no user has that id, so callers must check the
// pointer as well as the error. Any other scan error is returned unwrapped.
func (s *DailyQuestionService) getUserByID(ctx context.Context, userID int) (*models.User, error) {
	const selectUser = `
SELECT id, username, email, timezone, password_hash, last_active,
preferred_language, current_level, ai_provider, ai_model,
ai_enabled, ai_api_key, created_at, updated_at
FROM users
WHERE id = $1
`
	var u models.User
	scanErr := s.db.QueryRowContext(ctx, selectUser, userID).Scan(
		&u.ID, &u.Username, &u.Email, &u.Timezone, &u.PasswordHash,
		&u.LastActive, &u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider,
		&u.AIModel, &u.AIEnabled, &u.AIAPIKey, &u.CreatedAt, &u.UpdatedAt,
	)
	switch scanErr {
	case nil:
		return &u, nil
	case sql.ErrNoRows:
		// A missing user is reported as "no result", not as a failure.
		return nil, nil
	default:
		return nil, scanErr
	}
}
// SubmitDailyQuestionAnswer stores the user's answer for a daily question,
// marks the assignment as completed and returns the graded result.
//
// Steps:
//  1. Load the assignment for (userID, questionID, date). Returns
//     contextutils.ErrAssignmentNotFound when it does not exist and
//     contextutils.ErrQuestionAlreadyAnswered when it is completed and already
//     carries an answer index and a submission time.
//  2. Load the question and validate the answer against content["options"].
//     Returns contextutils.ErrQuestionNotFound, contextutils.ErrInvalidAnswerIndex,
//     or a formatted error when the options are missing or malformed. All
//     validation, including resolving the chosen option's text, happens before
//     anything is written.
//  3. Update the assignment inside a transaction.
//  4. Best-effort follow-ups when a learning service is configured; failures
//     here are logged and never returned: record the canonical user response,
//     link it to the assignment in daily_assignment_responses, and, for a
//     correct answer, delete this question's future assignments within the
//     repeat-avoidance window (7 days unless the question service exposes
//     getDailyRepeatAvoidDays).
//
// The response always includes the correct answer index, and the explanation
// when the question has one.
func (s *DailyQuestionService) SubmitDailyQuestionAnswer(ctx context.Context, userID, questionID int, date time.Time, userAnswerIndex int) (result *api.AnswerResponse, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "SubmitDailyQuestionAnswer",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.Int("question.id", questionID),
			attribute.String("date", date.Format("2006-01-02")),
			attribute.Int("user_answer_index", userAnswerIndex),
		),
	)
	// Flag the span as failed whenever the named error result is set, then end it.
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	s.logger.Info(ctx, "SubmitDailyQuestionAnswer started", map[string]interface{}{
		"user_id":           userID,
		"question_id":       questionID,
		"date":              date.Format("2006-01-02"),
		"user_answer_index": userAnswerIndex,
	})

	// Check if the question is already answered
	s.logger.Info(ctx, "Checking if question is already answered", map[string]interface{}{
		"user_id":     userID,
		"question_id": questionID,
		"date":        date.Format("2006-01-02"),
	})
	query := `
SELECT id, is_completed, user_answer_index, submitted_at
FROM daily_question_assignments
WHERE user_id = $1 AND question_id = $2 AND assignment_date = $3
`
	var assignmentID int
	var isCompleted bool
	var existingUserAnswerIndex *int
	var existingSubmittedAt *time.Time
	err = s.db.QueryRowContext(ctx, query, userID, questionID, date).Scan(
		&assignmentID, &isCompleted, &existingUserAnswerIndex, &existingSubmittedAt,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, contextutils.ErrAssignmentNotFound
		}
		return nil, contextutils.WrapError(err, "failed to check question assignment")
	}

	// Only a completed assignment that carries both an answer and a submission
	// time counts as answered; one merely marked completed may still be answered.
	if isCompleted && existingUserAnswerIndex != nil && existingSubmittedAt != nil {
		return nil, contextutils.ErrQuestionAlreadyAnswered
	}

	// Get the question details to validate answer and get correct answer
	question, err := s.questionService.GetQuestionByID(ctx, questionID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get question details")
	}
	if question == nil {
		return nil, contextutils.ErrQuestionNotFound
	}

	// Extract options from content map
	contentMap := question.Content
	s.logger.Info(ctx, "Question content debug", map[string]interface{}{
		"question_id": questionID,
		"content_map": contentMap,
	})
	optionsInterface, ok := contentMap["options"]
	if !ok {
		s.logger.Error(ctx, "Question content missing options", nil, map[string]interface{}{
			"question_id": questionID,
			"content_map": contentMap,
		})
		return nil, contextutils.ErrorWithContextf("question content missing options")
	}
	options, ok := optionsInterface.([]interface{})
	if !ok {
		s.logger.Error(ctx, "Invalid options format", nil, map[string]interface{}{
			"question_id":       questionID,
			"options_interface": optionsInterface,
			"options_type":      fmt.Sprintf("%T", optionsInterface),
		})
		return nil, contextutils.ErrorWithContextf("invalid options format")
	}

	// Validate user answer index
	if userAnswerIndex < 0 || userAnswerIndex >= len(options) {
		return nil, contextutils.ErrInvalidAnswerIndex
	}

	// Resolve the chosen option's text now, with a checked assertion. Doing this
	// after the commit with an unchecked assertion would panic on a non-string
	// option while leaving the assignment already marked as answered.
	userAnswer, ok := options[userAnswerIndex].(string)
	if !ok {
		s.logger.Error(ctx, "Invalid option format", nil, map[string]interface{}{
			"question_id":       questionID,
			"user_answer_index": userAnswerIndex,
			"option_type":       fmt.Sprintf("%T", options[userAnswerIndex]),
		})
		return nil, contextutils.ErrorWithContextf("invalid option format at index %d", userAnswerIndex)
	}

	// Check if answer is correct
	isCorrect := question.CorrectAnswer == userAnswerIndex

	// Begin transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to begin transaction")
	}
	defer func() {
		if err == nil {
			return
		}
		// Once Commit has been attempted the transaction is finished and Rollback
		// reports sql.ErrTxDone; that is expected and not worth an error log.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && rollbackErr != sql.ErrTxDone {
			s.logger.Error(ctx, "Failed to rollback transaction", rollbackErr, map[string]interface{}{
				"error": rollbackErr.Error(),
			})
		}
	}()

	// Update the assignment with the user's answer and mark as completed
	updateQuery := `
UPDATE daily_question_assignments
SET is_completed = true, completed_at = NOW(), user_answer_index = $1, submitted_at = NOW()
WHERE id = $2
`
	_, err = tx.ExecContext(ctx, updateQuery, userAnswerIndex, assignmentID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to update assignment")
	}

	// Commit transaction
	err = tx.Commit()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to commit transaction")
	}

	// Record canonical user response via learningService so history queries see is_correct
	// Use RecordAnswerWithPriorityReturningID to obtain user_responses.id so we can link it to the assignment.
	// Everything below is best effort: the answer is already committed, so
	// failures are logged with dedicated variables and never assigned to err.
	if s.learningService != nil {
		// record synchronously so we have the response id for mapping
		respID, recErr := s.learningService.RecordAnswerWithPriorityReturningID(ctx, userID, questionID, userAnswerIndex, isCorrect, 0)
		if recErr != nil {
			s.logger.Error(ctx, "Failed to record user response for daily answer", recErr, map[string]interface{}{
				"user_id":           userID,
				"question_id":       questionID,
				"user_answer_index": userAnswerIndex,
			})
		} else {
			// Insert mapping to daily_assignment_responses synchronously so tests that run immediately can observe it
			_, mapErr := s.db.ExecContext(ctx, `
INSERT INTO daily_assignment_responses (assignment_id, user_response_id, created_at)
VALUES ($1, $2, NOW())
ON CONFLICT (assignment_id) DO UPDATE SET user_response_id = EXCLUDED.user_response_id, created_at = EXCLUDED.created_at
`, assignmentID, respID)
			if mapErr != nil {
				// Log but don't fail user's request
				s.logger.Error(ctx, "Failed to insert daily_assignment_responses mapping", mapErr, map[string]interface{}{
					"assignment_id":    assignmentID,
					"user_response_id": respID,
				})
			}

			// If the answer was correct, remove future assignments for this question within the avoid window
			if isCorrect {
				// Determine avoidDays via questionService if possible; default to 7
				avoidDays := 7
				if qs, hasWindow := s.questionService.(interface{ getDailyRepeatAvoidDays() int }); hasWindow {
					avoidDays = qs.getDailyRepeatAvoidDays()
				}
				startDate := date.AddDate(0, 0, 1)
				endDate := date.AddDate(0, 0, avoidDays)
				deleteQuery := `DELETE FROM daily_question_assignments WHERE user_id = $1 AND question_id = $2 AND assignment_date >= $3 AND assignment_date <= $4`
				if _, delErr := s.db.ExecContext(ctx, deleteQuery, userID, questionID, startDate, endDate); delErr != nil {
					s.logger.Error(ctx, "Failed to delete future daily assignments", delErr, map[string]interface{}{
						"user_id":     userID,
						"question_id": questionID,
						"start":       startDate,
						"end":         endDate,
					})
				} else {
					// Future assignments removed successfully; worker will top up missing slots on its next run
					s.logger.Info(ctx, "Deleted future daily assignments for question; worker will refill dates as needed", map[string]interface{}{
						"user_id":     userID,
						"question_id": questionID,
						"start":       startDate,
						"end":         endDate,
					})
				}
			}
		}
	}

	// Build response
	response := &api.AnswerResponse{
		UserAnswerIndex: &userAnswerIndex,
		UserAnswer:      &userAnswer,
		IsCorrect:       &isCorrect,
	}
	// Add correct answer and explanation if available
	response.CorrectAnswerIndex = &question.CorrectAnswer
	if question.Explanation != "" {
		response.Explanation = &question.Explanation
	}

	s.logger.Info(ctx, "Daily question answer submitted", map[string]interface{}{
		"user_id":           userID,
		"question_id":       questionID,
		"date":              date.Format("2006-01-02"),
		"user_answer_index": userAnswerIndex,
		"is_correct":        isCorrect,
	})
	return response, nil
}
// Package services provides business logic services for the quiz application.
package services
import (
"context"
"database/sql"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/services/mailer"
)
// CreateEmailService picks the mailer implementation for the current mode:
// the regular EmailService normally, or a TestEmailService (announced with an
// info log) when cfg.IsTest is set.
func CreateEmailService(cfg *config.Config, logger *observability.Logger) mailer.Mailer {
	if !cfg.IsTest {
		return NewEmailService(cfg, logger)
	}

	logger.Info(context.Background(), "Using test email service", map[string]interface{}{
		"test_mode": true,
	})
	return NewTestEmailService(cfg, logger)
}
// CreateEmailServiceWithDB picks the database-backed mailer implementation for
// the current mode: a TestEmailService when cfg.IsTest is set, otherwise the
// regular EmailService.
//
// Outside test mode a nil db is logged and then panics, because the regular
// service cannot be built without a database connection.
func CreateEmailServiceWithDB(cfg *config.Config, logger *observability.Logger, db *sql.DB) mailer.Mailer {
	switch {
	case cfg.IsTest:
		logger.Info(context.Background(), "Using test email service with DB", map[string]interface{}{
			"test_mode": true,
		})
		return NewTestEmailServiceWithDB(cfg, logger, db)
	case db == nil:
		logger.Error(context.Background(), "Database connection is nil, cannot create EmailService", nil, map[string]interface{}{
			"error": "nil_database_connection",
		})
		panic("EmailService requires a non-nil database connection")
	}
	return NewEmailServiceWithDB(cfg, logger, db)
}
// Package services provides business logic services for the quiz application.
package services
import (
"context"
"database/sql"
"fmt"
"html/template"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
serviceinterfaces "quizapp/internal/serviceinterfaces"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"gopkg.in/mail.v2"
)
// EmailService implements serviceinterfaces.EmailService, sending mail over
// SMTP through a gopkg.in/mail.v2 dialer.
type EmailService struct {
	// cfg supplies the email switch, the SMTP settings and the app base URL
	// used in email links.
	cfg *config.Config
	// logger receives structured log entries for every send and record attempt.
	logger *observability.Logger
	// dialer is the SMTP dialer; the constructors leave it nil unless email is
	// enabled and an SMTP host is configured.
	dialer *mail.Dialer
	// db is the database handle; it is nil when the service was built with
	// NewEmailService instead of NewEmailServiceWithDB.
	db *sql.DB
}
// EmailServiceInterface is an alias for serviceinterfaces.EmailService, so the
// interface can be referred to from this package under a local name.
type EmailServiceInterface = serviceinterfaces.EmailService
// Compile-time assertion that *EmailService satisfies serviceinterfaces.EmailService.
var _ serviceinterfaces.EmailService = (*EmailService)(nil)
// NewEmailService builds an EmailService without a database connection.
//
// The SMTP dialer is only created when email is enabled and an SMTP host is
// configured; otherwise it stays nil.
func NewEmailService(cfg *config.Config, logger *observability.Logger) *EmailService {
	svc := &EmailService{
		cfg:    cfg,
		logger: logger,
	}
	if cfg.Email.Enabled && cfg.Email.SMTP.Host != "" {
		svc.dialer = mail.NewDialer(
			cfg.Email.SMTP.Host,
			cfg.Email.SMTP.Port,
			cfg.Email.SMTP.Username,
			cfg.Email.SMTP.Password,
		)
	}
	return svc
}
// NewEmailServiceWithDB builds an EmailService that also holds a database
// connection. It panics when db is nil.
//
// Dialer setup is identical to NewEmailService and is delegated to it.
func NewEmailServiceWithDB(cfg *config.Config, logger *observability.Logger, db *sql.DB) *EmailService {
	if db == nil {
		panic("EmailService requires a non-nil database connection")
	}

	svc := NewEmailService(cfg, logger)
	svc.db = db
	return svc
}
// SendDailyReminder emails the user their daily quiz reminder.
//
// Nothing is sent, and nil is returned, when email is disabled or the user has
// no usable email address. The daily goal shown in the email is read from
// user_learning_preferences; it falls back to 10 when the service has no
// database connection, the lookup fails, or the stored value is NULL.
// A send failure is wrapped and returned.
func (e *EmailService) SendDailyReminder(ctx context.Context, user *models.User) (err error) {
	ctx, span := otel.Tracer("email-service").Start(ctx, "SendDailyReminder",
		trace.WithAttributes(
			attribute.Int("user.id", user.ID),
			attribute.String("user.email", user.Email.String),
		),
	)
	defer observability.FinishSpan(span, &err)

	if !e.IsEnabled() {
		e.logger.Info(ctx, "Email disabled, skipping daily reminder", map[string]interface{}{
			"user_id": user.ID,
			"email":   user.Email.String,
		})
		return nil
	}
	if !user.Email.Valid || user.Email.String == "" {
		e.logger.Warn(ctx, "User has no email address, skipping daily reminder", map[string]interface{}{
			"user_id": user.ID,
		})
		return nil
	}

	// Determine daily goal from DB. A service built with NewEmailService has no
	// connection, so the lookup must be skipped rather than dereferencing nil.
	dailyGoal := 10
	if e.db != nil {
		var dg sql.NullInt64
		// A failed lookup is deliberately ignored: the reminder still goes out
		// with the default goal.
		if lookupErr := e.db.QueryRowContext(ctx, "SELECT daily_goal FROM user_learning_preferences WHERE user_id = $1", user.ID).Scan(&dg); lookupErr == nil && dg.Valid {
			dailyGoal = int(dg.Int64)
		}
	}

	// Generate email data
	data := map[string]interface{}{
		"Username":       user.Username,
		"QuizAppURL":     e.cfg.Server.AppBaseURL, // Frontend app URL for email links
		"CurrentDate":    time.Now().Format("January 2, 2006"),
		"DailyGoal":      dailyGoal,
		"UnsubscribeURL": fmt.Sprintf("%s/settings", e.cfg.Server.AppBaseURL),
	}
	subject := "Time for your daily quiz! ð"
	err = e.SendEmail(ctx, user.Email.String, subject, "daily_reminder", data)
	if err != nil {
		return contextutils.WrapError(err, "failed to send daily reminder")
	}

	e.logger.Info(ctx, "Daily reminder sent successfully", map[string]interface{}{
		"user_id": user.ID,
		"email":   user.Email.String,
	})
	return nil
}
// SendEmail renders the named template with data and delivers it as an HTML
// message to a single recipient.
//
// It returns nil without sending when email is disabled, an error when the
// service has no SMTP dialer, and a wrapped error when rendering or delivery
// fails.
func (e *EmailService) SendEmail(ctx context.Context, to, subject, templateName string, data map[string]interface{}) (err error) {
	ctx, span := otel.Tracer("email-service").Start(ctx, "SendEmail",
		trace.WithAttributes(
			attribute.String("email.to", to),
			attribute.String("email.subject", subject),
			attribute.String("email.template", templateName),
		),
	)
	defer observability.FinishSpan(span, &err)

	switch {
	case !e.IsEnabled():
		e.logger.Info(ctx, "Email disabled, skipping email send", map[string]interface{}{
			"to":       to,
			"template": templateName,
		})
		return nil
	case e.dialer == nil:
		return contextutils.ErrorWithContextf("email service not properly configured")
	}

	// Render the body before assembling the message; a template failure means
	// there is nothing to send.
	body, renderErr := e.generateEmailContent(templateName, data)
	if renderErr != nil {
		return contextutils.WrapError(renderErr, "failed to generate email content")
	}

	sender := fmt.Sprintf("%s <%s>", e.cfg.Email.SMTP.FromName, e.cfg.Email.SMTP.FromAddress)
	msg := mail.NewMessage()
	msg.SetHeader("From", sender)
	msg.SetHeader("To", to)
	msg.SetHeader("Subject", subject)
	msg.SetBody("text/html", body)

	// deliveryFields builds a fresh field map per log call.
	deliveryFields := func() map[string]interface{} {
		return map[string]interface{}{
			"to":       to,
			"template": templateName,
			"subject":  subject,
		}
	}

	if sendErr := e.dialer.DialAndSend(msg); sendErr != nil {
		e.logger.Error(ctx, "Failed to send email", sendErr, deliveryFields())
		return contextutils.WrapError(sendErr, "failed to send email")
	}

	e.logger.Info(ctx, "Email sent successfully", deliveryFields())
	return nil
}
// RecordSentNotification inserts one row into sent_notifications describing a
// notification attempt, stamped with the current time.
//
// It returns an error when the service has no database connection, and a
// wrapped error when the insert fails.
func (e *EmailService) RecordSentNotification(ctx context.Context, userID int, notificationType, subject, templateName, status, errorMessage string) (err error) {
	ctx, span := otel.Tracer("email-service").Start(ctx, "RecordSentNotification",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("notification.type", notificationType),
			attribute.String("notification.status", status),
		),
	)
	defer observability.FinishSpan(span, &err)

	if e.db == nil {
		e.logger.Error(ctx, "Database connection is nil, cannot record notification", nil, map[string]interface{}{
			"user_id":           userID,
			"notification_type": notificationType,
		})
		return contextutils.ErrorWithContextf("EmailService database connection is nil")
	}

	// outcomeFields builds a fresh field map per log call.
	outcomeFields := func() map[string]interface{} {
		return map[string]interface{}{
			"user_id":           userID,
			"notification_type": notificationType,
			"status":            status,
		}
	}

	const insertNotification = `
INSERT INTO sent_notifications (user_id, notification_type, subject, template_name, sent_at, status, error_message)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	if _, execErr := e.db.ExecContext(ctx, insertNotification, userID, notificationType, subject, templateName, time.Now(), status, errorMessage); execErr != nil {
		e.logger.Error(ctx, "Failed to record sent notification", execErr, outcomeFields())
		return contextutils.WrapError(execErr, "failed to record sent notification")
	}

	e.logger.Info(ctx, "Recorded sent notification", outcomeFields())
	return nil
}
// IsEnabled reports whether emails can be sent: the email feature must be
// switched on in the configuration and an SMTP host must be set.
func (e *EmailService) IsEnabled() bool {
	if !e.cfg.Email.Enabled {
		return false
	}
	return e.cfg.Email.SMTP.Host != ""
}
// generateEmailContent renders the HTML body for a named template.
//
// Templates are compiled into the binary; the supported names are
// "daily_reminder", "test_email" and "word_of_the_day". Any other name yields
// an "unknown template" error.
func (e *EmailService) generateEmailContent(templateName string, data map[string]interface{}) (string, error) {
	var (
		html      string
		renderErr error
	)
	switch templateName {
	case "daily_reminder":
		html, renderErr = e.generateDailyReminderTemplate(data)
	case "test_email":
		html, renderErr = e.generateTestEmailTemplate(data)
	case "word_of_the_day":
		html, renderErr = e.generateWordOfTheDayTemplate(data)
	default:
		renderErr = contextutils.ErrorWithContextf("unknown template: %s", templateName)
	}
	return html, renderErr
}
// generateDailyReminderTemplate renders the daily reminder email as HTML.
//
// The template reads Username, CurrentDate, DailyGoal, QuizAppURL and
// UnsubscribeURL from data. It is parsed with html/template on every call;
// parse and execution failures are wrapped and returned with an empty string.
func (e *EmailService) generateDailyReminderTemplate(data map[string]interface{}) (string, error) {
	const templateStr = `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Daily Quiz Reminder</title>
<style>
body { font-family: Arial, sans-serif; line-height: 1.6; color: #333; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background-color: #4CAF50; color: white; padding: 20px; text-align: center; border-radius: 5px 5px 0 0; }
.content { background-color: #f9f9f9; padding: 20px; }
.button { display: inline-block; background-color: #4CAF50; color: white; padding: 12px 24px; text-decoration: none; border-radius: 5px; margin: 20px 0; }
.footer { background-color: #eee; padding: 15px; text-align: center; font-size: 12px; color: #666; border-radius: 0 0 5px 5px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>ð Daily Quiz Reminder</h1>
</div>
<div class="content">
<h2>Hello {{.Username}}!</h2>
<p>It's {{.CurrentDate}} and time for your daily questions!</p>
<p>Your goal today: <strong>{{.DailyGoal}} questions</strong></p>
<p>Keep up the great work and continue improving your language skills!</p>
<div style="text-align: center;">
<a href="{{.QuizAppURL}}/daily" class="button">Start Your Daily Questions</a>
</div>
</div>
<div class="footer">
<p>This email was sent by Quiz App. If you no longer wish to receive these reminders, you can <a href="{{.UnsubscribeURL}}">unsubscribe here</a>.</p>
</div>
</div>
</body>
</html>`

	parsed, parseErr := template.New("daily_reminder").Parse(templateStr)
	if parseErr != nil {
		return "", contextutils.WrapError(parseErr, "failed to parse template")
	}

	rendered := new(strings.Builder)
	if execErr := parsed.Execute(rendered, data); execErr != nil {
		return "", contextutils.WrapError(execErr, "failed to execute template")
	}
	return rendered.String(), nil
}
// generateTestEmailTemplate renders the HTML body of the test email used to
// verify email settings.
//
// The template reads the following keys from data: Username, TestTime and
// Message.
//
// It returns the rendered document, or a wrapped error when the template fails
// to parse or execute.
func (e *EmailService) generateTestEmailTemplate(data map[string]interface{}) (string, error) {
// The markup is embedded as a constant; it is parsed again on every call.
const templateStr = `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Test Email</title>
<style>
body { font-family: Arial, sans-serif; line-height: 1.6; color: #333; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background-color: #2196F3; color: white; padding: 20px; text-align: center; border-radius: 5px 5px 0 0; }
.content { background-color: #f9f9f9; padding: 20px; }
.footer { background-color: #eee; padding: 15px; text-align: center; font-size: 12px; color: #666; border-radius: 0 0 5px 5px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>ð Test Email</h1>
</div>
<div class="content">
<h2>Hello {{.Username}}!</h2>
<p>This is a test email to verify that your email settings are working correctly.</p>
<p><strong>Test Time:</strong> {{.TestTime}}</p>
<p><strong>Message:</strong> {{.Message}}</p>
<p>If you received this email, your email configuration is working properly!</p>
</div>
<div class="footer">
<p>This is a test email from Quiz App. No action is required.</p>
</div>
</div>
</body>
</html>
`
// NOTE(review): the heading glyph before "Test Email" looks like a mis-encoded
// emoji — confirm the file encoding before editing the literal.
tmpl, err := template.New("test_email").Parse(templateStr)
if err != nil {
return "", contextutils.WrapError(err, "failed to parse template")
}
// Render into an in-memory buffer so a failed execution never yields partial output.
var buf strings.Builder
if err := tmpl.Execute(&buf, data); err != nil {
return "", contextutils.WrapError(err, "failed to execute template")
}
return buf.String(), nil
}
// SendWordOfTheDayEmail sends the word-of-the-day email for date to the user
// identified by userID.
//
// The call is a no-op (returns nil) when email is disabled for the service or
// when the user has word-of-the-day emails turned off. It returns an error when
// the user cannot be loaded, does not exist, has no email address, or when
// wordOfTheDay is nil. Otherwise it returns the result of SendEmail using the
// "word_of_the_day" template.
func (e *EmailService) SendWordOfTheDayEmail(ctx context.Context, userID int, date time.Time, wordOfTheDay *models.WordOfTheDayDisplay) (err error) {
	isoDate := date.Format("2006-01-02")
	ctx, span := otel.Tracer("email-service").Start(ctx, "SendWordOfTheDayEmail",
		trace.WithAttributes(
			attribute.Int("email.user_id", userID),
			attribute.String("email.date", isoDate),
		),
	)
	defer observability.FinishSpan(span, &err)

	if !e.IsEnabled() {
		e.logger.Info(ctx, "Email disabled, skipping word of the day email", map[string]interface{}{
			"user_id": userID,
			"date":    isoDate,
		})
		return nil
	}

	// Get user to check email preferences
	user, err := e.getUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to get user")
	}
	if user == nil {
		return contextutils.ErrorWithContextf("user not found: %d", userID)
	}

	// Check if user has email disabled for word of the day
	if !user.WordOfDayEmailEnabled.Bool {
		e.logger.Info(ctx, "User has word of the day emails disabled", map[string]interface{}{
			"user_id": userID,
		})
		return nil
	}
	if !user.Email.Valid || user.Email.String == "" {
		return contextutils.ErrorWithContextf("user has no email address")
	}

	// The payload is dereferenced below; fail with an error instead of panicking.
	// The check sits after the opt-out paths so those keep returning nil as before.
	if wordOfTheDay == nil {
		return contextutils.ErrorWithContextf("word of the day is nil for user %d", userID)
	}

	// Prepare email data
	displayDate := date.Format("January 2, 2006")
	data := map[string]interface{}{
		"Username":       user.Username,
		"Word":           wordOfTheDay.Word,
		"Translation":    wordOfTheDay.Translation,
		"Sentence":       wordOfTheDay.Sentence,
		"Date":           displayDate,
		"Language":       wordOfTheDay.Language,
		"Level":          wordOfTheDay.Level,
		"Explanation":    wordOfTheDay.Explanation,
		"QuizAppURL":     e.cfg.Server.AppBaseURL,
		"UnsubscribeURL": fmt.Sprintf("%s/settings?tab=notifications", e.cfg.Server.AppBaseURL),
	}
	subject := fmt.Sprintf("Word of the Day: %s - %s", wordOfTheDay.Word, displayDate)
	return e.SendEmail(ctx, user.Email.String, subject, "word_of_the_day", data)
}
// generateWordOfTheDayTemplate renders the HTML body of the word-of-the-day email.
//
// The template reads the following keys from data: Date, Word, Translation,
// Sentence, Explanation, Language, Level, QuizAppURL (the "/word-of-day" path is
// appended to it for the call-to-action link) and UnsubscribeURL. The Sentence,
// Explanation, Language and Level sections are wrapped in {{if}} guards and are
// only emitted when the corresponding value is non-empty.
//
// It returns the rendered document, or a wrapped error when the template fails
// to parse or execute.
func (e *EmailService) generateWordOfTheDayTemplate(data map[string]interface{}) (string, error) {
// The markup is embedded as a constant; it is parsed again on every call.
const templateStr = `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Word of the Day</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif; line-height: 1.6; color: #333; margin: 0; padding: 0; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px 20px; text-align: center; border-radius: 8px 8px 0 0; }
.content { background-color: #ffffff; padding: 30px; border: 1px solid #e0e0e0; border-top: none; }
.date { color: #667eea; font-size: 14px; font-weight: 600; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 15px; }
.word { font-size: 48px; font-weight: bold; color: #1a1a1a; margin-bottom: 15px; line-height: 1.2; }
.translation { font-size: 24px; color: #667eea; margin-bottom: 25px; font-style: italic; }
.sentence { font-size: 18px; line-height: 1.8; color: #555; background: #f7f7f7; padding: 25px; border-radius: 8px; border-left: 4px solid #667eea; margin-bottom: 20px; font-style: italic; }
.explanation { font-size: 15px; color: #666; margin-top: 20px; padding: 20px; background: #fafafa; border-radius: 8px; border-left: 3px solid #764ba2; }
.meta { display: flex; gap: 10px; flex-wrap: wrap; margin-top: 20px; }
.badge { background: #e0e7ff; color: #667eea; padding: 6px 12px; border-radius: 20px; font-size: 12px; font-weight: 600; }
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 14px 28px; text-decoration: none; border-radius: 6px; margin: 20px 0; font-weight: 600; }
.footer { background-color: #f5f5f5; padding: 20px; text-align: center; font-size: 12px; color: #666; border-radius: 0 0 8px 8px; border: 1px solid #e0e0e0; border-top: none; }
.footer a { color: #667eea; text-decoration: none; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1 style="margin: 0; font-size: 28px;">ð Word of the Day</h1>
</div>
<div class="content">
<div class="date">{{.Date}}</div>
<div class="word">{{.Word}}</div>
<div class="translation">{{.Translation}}</div>
{{if .Sentence}}
<div class="sentence">{{.Sentence}}</div>
{{end}}
{{if .Explanation}}
<div class="explanation">{{.Explanation}}</div>
{{end}}
<div class="meta">
{{if .Language}}<span class="badge">{{.Language}}</span>{{end}}
{{if .Level}}<span class="badge">{{.Level}}</span>{{end}}
</div>
<div style="text-align: center; margin-top: 30px;">
<a href="{{.QuizAppURL}}/word-of-day" class="button">View in App</a>
</div>
</div>
<div class="footer">
<p>This email was sent by Quiz App. If you no longer wish to receive word of the day emails, you can <a href="{{.UnsubscribeURL}}">update your preferences here</a>.</p>
</div>
</div>
</body>
</html>`
// NOTE(review): the heading glyph before "Word of the Day" looks like a
// mis-encoded emoji — confirm the file encoding before editing the literal.
tmpl, err := template.New("word_of_the_day").Parse(templateStr)
if err != nil {
return "", contextutils.WrapError(err, "failed to parse template")
}
// Render into an in-memory buffer so a failed execution never yields partial output.
var buf strings.Builder
if err := tmpl.Execute(&buf, data); err != nil {
return "", contextutils.WrapError(err, "failed to execute template")
}
return buf.String(), nil
}
// getUserByID loads the id, username, email and word-of-the-day email preference
// of the user with the given ID.
//
// It returns (nil, nil) when no such user exists, and an error when the service
// has no database handle or the query fails.
func (e *EmailService) getUserByID(ctx context.Context, userID int) (*models.User, error) {
	if e.db == nil {
		return nil, contextutils.ErrorWithContextf("database connection not available")
	}

	const selectUser = `
SELECT id, username, email, word_of_day_email_enabled
FROM users
WHERE id = $1
`
	found := new(models.User)
	row := e.db.QueryRowContext(ctx, selectUser, userID)
	switch scanErr := row.Scan(&found.ID, &found.Username, &found.Email, &found.WordOfDayEmailEnabled); {
	case scanErr == sql.ErrNoRows:
		// A missing user is reported as a nil result, not as an error.
		return nil, nil
	case scanErr != nil:
		return nil, contextutils.WrapError(scanErr, "failed to query user")
	}
	return found, nil
}
package services
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/attribute"
)
// FeedbackService implements FeedbackServiceInterface for managing feedback reports.
// Its methods read and write rows of the feedback_reports table through db.
type FeedbackService struct {
// db is the database handle used for every feedback_reports query.
db *sql.DB
// logger is the logger supplied to NewFeedbackService.
logger *observability.Logger
}
// NewFeedbackService builds a FeedbackService around db and logger.
//
// Both dependencies are mandatory: the constructor panics if either is nil.
func NewFeedbackService(db *sql.DB, logger *observability.Logger) *FeedbackService {
	switch {
	case db == nil:
		panic("NewFeedbackService: db is nil")
	case logger == nil:
		panic("NewFeedbackService: logger is nil")
	}
	svc := new(FeedbackService)
	svc.db = db
	svc.logger = logger
	return svc
}
// CreateFeedback stores fr as a new row in feedback_reports.
//
// fr.ContextData is serialized to JSON, the status is always stored as "new",
// and the current time is used for both created_at and updated_at. On success
// fr is updated in place with the generated ID, the status and the timestamps
// returned by the database, and the same pointer is returned. On failure fr is
// left untouched and a wrapped error is returned.
func (s *FeedbackService) CreateFeedback(ctx context.Context, fr *models.FeedbackReport) (result0 *models.FeedbackReport, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "create_feedback")
	defer observability.FinishSpan(span, &err)

	const initialStatus = "new"
	const insertSQL = `INSERT INTO feedback_reports (user_id, feedback_text, feedback_type, context_data, screenshot_data, screenshot_url, status, created_at, updated_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9) RETURNING id, created_at, updated_at`

	payload, err := json.Marshal(fr.ContextData)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to marshal context_data")
	}

	// Scan into locals first so fr is only mutated once the insert succeeded.
	var inserted struct {
		id        int
		createdAt time.Time
		updatedAt time.Time
	}
	stamp := time.Now()
	row := s.db.QueryRowContext(ctx, insertSQL,
		fr.UserID, fr.FeedbackText, fr.FeedbackType, payload,
		fr.ScreenshotData, fr.ScreenshotURL, initialStatus, stamp, stamp,
	)
	if err = row.Scan(&inserted.id, &inserted.createdAt, &inserted.updatedAt); err != nil {
		return nil, contextutils.WrapError(err, "failed to insert feedback report")
	}

	fr.ID = inserted.id
	fr.Status = initialStatus
	fr.CreatedAt = inserted.createdAt
	fr.UpdatedAt = inserted.updatedAt
	return fr, nil
}
// GetFeedbackByID loads the feedback report with the given id.
//
// It returns contextutils.ErrRecordNotFound when no row matches and a wrapped
// error for any other scan failure. The stored context_data JSON is decoded
// into ContextData on a best-effort basis: decode errors are ignored.
func (s *FeedbackService) GetFeedbackByID(ctx context.Context, id int) (result0 *models.FeedbackReport, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_feedback_by_id")
	defer observability.FinishSpan(span, &err)

	const selectSQL = `SELECT id, user_id, feedback_text, feedback_type, context_data, screenshot_data, screenshot_url, status, admin_notes, assigned_to_user_id, resolved_at, resolved_by_user_id, created_at, updated_at FROM feedback_reports WHERE id=$1`

	var (
		report     models.FeedbackReport
		rawContext []byte
	)
	err = s.db.QueryRowContext(ctx, selectSQL, id).Scan(
		&report.ID, &report.UserID, &report.FeedbackText, &report.FeedbackType,
		&rawContext, &report.ScreenshotData, &report.ScreenshotURL, &report.Status,
		&report.AdminNotes, &report.AssignedToUserID, &report.ResolvedAt,
		&report.ResolvedByUserID, &report.CreatedAt, &report.UpdatedAt,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil, contextutils.ErrRecordNotFound
	case err != nil:
		return nil, contextutils.WrapError(err, "failed to scan feedback")
	}

	// Best-effort decode: a report is still returned when context_data is unreadable.
	_ = json.Unmarshal(rawContext, &report.ContextData)
	return &report, nil
}
// GetFeedbackPaginated returns one page of feedback reports, newest first,
// together with the total number of reports matching the filters.
//
// status and feedbackType are applied as equality filters when non-empty, and
// userID when non-nil. page is treated as 1-based: the query skips
// (page-1)*pageSize rows and returns at most pageSize rows. On success the
// returned slice is never nil. Any count, query, scan or row-iteration failure
// is returned as a wrapped error with a nil slice and a zero total.
func (s *FeedbackService) GetFeedbackPaginated(ctx context.Context, page, pageSize int, status, feedbackType string, userID *int) (result0 []models.FeedbackReport, result1 int, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_feedback_paginated")
	defer observability.FinishSpan(span, &err)

	var (
		conditions []string
		args       []interface{}
	)
	// addFilter registers "column=$N" where N is the 1-based position of value
	// in args. column is always one of the literals below, never caller input.
	addFilter := func(column string, value interface{}) {
		args = append(args, value)
		conditions = append(conditions, fmt.Sprintf("%s=$%d", column, len(args)))
	}
	if status != "" {
		addFilter("status", status)
	}
	if feedbackType != "" {
		addFilter("feedback_type", feedbackType)
	}
	if userID != nil {
		addFilter("user_id", *userID)
	}

	where := ""
	if len(conditions) > 0 {
		where = "WHERE " + strings.Join(conditions, " AND ")
	}

	// The count uses the filter arguments only; LIMIT/OFFSET are appended afterwards.
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM feedback_reports %s", where)
	var total int
	if err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, contextutils.WrapError(err, "failed to count feedback")
	}

	limitIdx := len(args) + 1
	offset := (page - 1) * pageSize
	args = append(args, pageSize, offset)
	query := fmt.Sprintf("SELECT id, user_id, feedback_text, feedback_type, context_data, screenshot_data, screenshot_url, status, admin_notes, assigned_to_user_id, resolved_at, resolved_by_user_id, created_at, updated_at FROM feedback_reports %s ORDER BY created_at DESC LIMIT $%d OFFSET $%d", where, limitIdx, limitIdx+1)

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, contextutils.WrapError(err, "failed to query feedback list")
	}
	defer func() {
		_ = rows.Close()
	}()

	// Non-nil on purpose so an empty page is distinguishable from a failure.
	list := []models.FeedbackReport{}
	for rows.Next() {
		var fr models.FeedbackReport
		var contextJSON []byte
		if err := rows.Scan(&fr.ID, &fr.UserID, &fr.FeedbackText, &fr.FeedbackType, &contextJSON, &fr.ScreenshotData, &fr.ScreenshotURL, &fr.Status, &fr.AdminNotes, &fr.AssignedToUserID, &fr.ResolvedAt, &fr.ResolvedByUserID, &fr.CreatedAt, &fr.UpdatedAt); err != nil {
			return nil, 0, contextutils.WrapError(err, "scan feedback list")
		}
		// Best-effort decode: a report is still listed when context_data is unreadable.
		_ = json.Unmarshal(contextJSON, &fr.ContextData)
		list = append(list, fr)
	}
	// rows.Next returns false both at end-of-data and on failure; without this
	// check an iteration error would silently produce a truncated page.
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapError(err, "error iterating feedback list")
	}
	return list, total, nil
}
// UpdateFeedback applies updates (column name -> new value) to the feedback
// report with the given id and returns the report as stored afterwards.
//
// updated_at is always set to the current time. With an empty updates map no
// write happens and the current report is returned. Values are passed as bound
// parameters, but column names have to be spliced into the statement text, so
// every key must be a plain SQL identifier; any other key is rejected with an
// error before the database is touched.
func (s *FeedbackService) UpdateFeedback(ctx context.Context, id int, updates map[string]interface{}) (result0 *models.FeedbackReport, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_feedback", attribute.Int("feedback.id", id))
	defer observability.FinishSpan(span, &err)

	if len(updates) == 0 {
		return s.GetFeedbackByID(ctx, id)
	}

	sets := make([]string, 0, len(updates)+1)
	args := make([]interface{}, 0, len(updates)+2)
	for column, value := range updates {
		if !isFeedbackColumnName(column) {
			return nil, contextutils.ErrorWithContextf("invalid feedback update field: %q", column)
		}
		args = append(args, value)
		sets = append(sets, fmt.Sprintf("%s=$%d", column, len(args)))
	}

	args = append(args, time.Now())
	sets = append(sets, fmt.Sprintf("updated_at=$%d", len(args)))

	// The row ID is always the last placeholder.
	args = append(args, id)
	query := fmt.Sprintf("UPDATE feedback_reports SET %s WHERE id=$%d", strings.Join(sets, ","), len(args))
	if _, err = s.db.ExecContext(ctx, query, args...); err != nil {
		return nil, contextutils.WrapError(err, "failed to update feedback")
	}
	return s.GetFeedbackByID(ctx, id)
}

// isFeedbackColumnName reports whether name consists only of ASCII letters,
// digits and underscores and does not start with a digit, i.e. whether it is
// safe to splice into SQL text as an unquoted column name.
func isFeedbackColumnName(name string) bool {
	if name == "" {
		return false
	}
	for i := 0; i < len(name); i++ {
		c := name[i]
		switch {
		case c == '_', 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z':
		case '0' <= c && c <= '9' && i > 0:
		default:
			return false
		}
	}
	return true
}
// DeleteFeedback removes the feedback report with the given id.
//
// It returns an error wrapping contextutils.ErrRecordNotFound when no row was
// deleted, and a wrapped error when the statement or the affected-row count fails.
func (s *FeedbackService) DeleteFeedback(ctx context.Context, id int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "delete_feedback", attribute.Int("feedback.id", id))
	defer observability.FinishSpan(span, &err)

	res, err := s.db.ExecContext(ctx, `DELETE FROM feedback_reports WHERE id=$1`, id)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete feedback")
	}

	deleted, err := res.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if deleted != 0 {
		return nil
	}
	return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "feedback with ID %d not found", id)
}
// DeleteFeedbackByStatus removes every feedback report whose status equals
// status and returns how many rows were deleted.
func (s *FeedbackService) DeleteFeedbackByStatus(ctx context.Context, status string) (result0 int, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "delete_feedback_by_status", attribute.String("status", status))
	defer observability.FinishSpan(span, &err)

	const deleteSQL = `DELETE FROM feedback_reports WHERE status=$1`
	outcome, err := s.db.ExecContext(ctx, deleteSQL, status)
	if err != nil {
		return 0, contextutils.WrapError(err, "failed to delete feedback by status")
	}

	count, err := outcome.RowsAffected()
	if err != nil {
		return 0, contextutils.WrapError(err, "failed to get rows affected")
	}
	return int(count), nil
}
// DeleteAllFeedback removes every row from feedback_reports, whatever its
// status, and returns how many rows were deleted.
func (s *FeedbackService) DeleteAllFeedback(ctx context.Context) (result0 int, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "delete_all_feedback")
	defer observability.FinishSpan(span, &err)

	outcome, err := s.db.ExecContext(ctx, `DELETE FROM feedback_reports`)
	if err != nil {
		return 0, contextutils.WrapError(err, "failed to delete all feedback")
	}

	var removed int64
	if removed, err = outcome.RowsAffected(); err != nil {
		return 0, contextutils.WrapError(err, "failed to get rows affected")
	}
	return int(removed), nil
}
package services
import (
"context"
"database/sql"
"time"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
)
// GenerationHint represents an active generation hint, i.e. one row of the
// generation_hints table. The db tags name the column each field is read from.
type GenerationHint struct {
// ID is the row identifier.
ID int `db:"id"`
// UserID is the user the hint belongs to.
UserID int `db:"user_id"`
// Language, Level and QuestionType, together with UserID, form the conflict
// key used by UpsertHint.
Language string `db:"language"`
Level string `db:"level"`
QuestionType string `db:"question_type"`
// PriorityWeight starts at 1 and is incremented each time UpsertHint hits an
// existing row.
PriorityWeight int `db:"priority_weight"`
// ExpiresAt is the instant after which GetActiveHintsForUser no longer
// returns the hint.
ExpiresAt time.Time `db:"expires_at"`
// CreatedAt is preserved across upserts and orders the result of
// GetActiveHintsForUser (oldest first).
CreatedAt time.Time `db:"created_at"`
}
// GenerationHintServiceInterface defines the API for managing generation hints.
// GenerationHintService is the implementation in this file.
type GenerationHintServiceInterface interface {
// UpsertHint creates the hint for (userID, language, level, qType) or
// refreshes an existing one so that it expires ttl from now.
UpsertHint(ctx context.Context, userID int, language, level string, qType models.QuestionType, ttl time.Duration) error
// GetActiveHintsForUser returns the user's hints that have not expired yet.
GetActiveHintsForUser(ctx context.Context, userID int) ([]GenerationHint, error)
// ClearHint deletes the hint for (userID, language, level, qType).
ClearHint(ctx context.Context, userID int, language, level string, qType models.QuestionType) error
}
// GenerationHintService implements hint management on top of the
// generation_hints table.
type GenerationHintService struct {
// db is the database handle used for every generation_hints statement.
db *sql.DB
// logger is the logger supplied to NewGenerationHintService.
logger *observability.Logger
}
// NewGenerationHintService constructs a service for managing short-lived per-user
// generation hints that nudge the worker to prioritize specific question types
// (e.g., reading comprehension) when the user is waiting for generation.
//
// db and logger are stored as given; neither is validated here.
func NewGenerationHintService(db *sql.DB, logger *observability.Logger) *GenerationHintService {
	svc := new(GenerationHintService)
	svc.db, svc.logger = db, logger
	return svc
}
// UpsertHint creates the hint for (userID, language, level, qType) or refreshes
// the existing one.
//
// A new hint starts with priority_weight 1. When the hint already exists its
// priority_weight is incremented, its expiry is moved to now+ttl and its
// created_at is kept. A failed statement is returned as a wrapped error.
func (s *GenerationHintService) UpsertHint(ctx context.Context, userID int, language, level string, qType models.QuestionType, ttl time.Duration) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "upsert_generation_hint")
	defer observability.FinishSpan(span, &err)

	const upsertSQL = `
INSERT INTO generation_hints (user_id, language, level, question_type, priority_weight, expires_at)
VALUES ($1, $2, $3, $4, 1, $5)
ON CONFLICT (user_id, language, level, question_type) DO UPDATE SET
priority_weight = generation_hints.priority_weight + 1,
expires_at = EXCLUDED.expires_at,
created_at = generation_hints.created_at
`
	deadline := time.Now().Add(ttl)
	if _, err = s.db.ExecContext(ctx, upsertSQL, userID, language, level, string(qType), deadline); err != nil {
		return contextutils.WrapError(err, "failed to upsert generation hint")
	}
	return nil
}
// GetActiveHintsForUser lists the hints of userID whose expiry lies in the
// future according to the database clock, oldest first.
//
// The result is nil when the user has no active hints. Query, scan and
// iteration failures are returned as wrapped errors.
func (s *GenerationHintService) GetActiveHintsForUser(ctx context.Context, userID int) (result0 []GenerationHint, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_active_generation_hints")
	defer observability.FinishSpan(span, &err)

	const selectSQL = `
SELECT id, user_id, language, level, question_type, priority_weight, expires_at, created_at
FROM generation_hints
WHERE user_id = $1 AND expires_at > NOW()
ORDER BY created_at ASC
`
	rows, err := s.db.QueryContext(ctx, selectSQL, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query generation hints")
	}
	defer func() { _ = rows.Close() }()

	var active []GenerationHint
	for rows.Next() {
		hint := GenerationHint{}
		scanErr := rows.Scan(
			&hint.ID, &hint.UserID, &hint.Language, &hint.Level,
			&hint.QuestionType, &hint.PriorityWeight, &hint.ExpiresAt, &hint.CreatedAt,
		)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan generation hint")
		}
		active = append(active, hint)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating generation hints")
	}
	return active, nil
}
// ClearHint removes the hint identified by (userID, language, level, qType).
//
// Deleting a hint that does not exist is not an error; only a failed statement
// is reported, as a wrapped error.
func (s *GenerationHintService) ClearHint(ctx context.Context, userID int, language, level string, qType models.QuestionType) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "clear_generation_hint")
	defer observability.FinishSpan(span, &err)

	const deleteSQL = `
DELETE FROM generation_hints
WHERE user_id = $1 AND language = $2 AND level = $3 AND question_type = $4
`
	if _, err = s.db.ExecContext(ctx, deleteSQL, userID, language, level, string(qType)); err != nil {
		return contextutils.WrapError(err, "failed to clear generation hint")
	}
	return nil
}
package services
import (
"context"
"database/sql"
"fmt"
"math"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"github.com/lib/pq"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// LearningServiceInterface defines the interface for the learning service.
// The methods are grouped into response recording and progress, the priority
// system, system-wide analytics, per-user analytics, and helpers for the
// progress API.
type LearningServiceInterface interface {
// RecordUserResponse stores a response and updates the performance metrics
// of the answered question's topic.
RecordUserResponse(ctx context.Context, response *models.UserResponse) error
// GetUserProgress returns the aggregated learning progress of a user.
GetUserProgress(ctx context.Context, userID int) (*models.UserProgress, error)
GetWeakestTopics(ctx context.Context, userID, limit int) ([]*models.PerformanceMetrics, error)
ShouldAvoidQuestion(ctx context.Context, userID, questionID int) (bool, error)
GetUserQuestionStats(ctx context.Context, userID int) (*UserQuestionStats, error)
// Priority system methods
RecordAnswerWithPriority(ctx context.Context, userID, questionID, answerIndex int, isCorrect bool, responseTime int) error
// RecordAnswerWithPriorityReturningID records the response and returns the created user_responses.id
RecordAnswerWithPriorityReturningID(ctx context.Context, userID, questionID, answerIndex int, isCorrect bool, responseTime int) (int, error)
MarkQuestionAsKnown(ctx context.Context, userID, questionID int, confidenceLevel *int) error
GetUserLearningPreferences(ctx context.Context, userID int) (*models.UserLearningPreferences, error)
UpdateLastDailyReminderSent(ctx context.Context, userID int) error
CalculatePriorityScore(ctx context.Context, userID, questionID int) (float64, error)
UpdateUserLearningPreferences(ctx context.Context, userID int, prefs *models.UserLearningPreferences) (*models.UserLearningPreferences, error)
GetUserQuestionConfidenceLevel(ctx context.Context, userID, questionID int) (*int, error)
// Analytics methods (no userID parameter; presumably system-wide — verify
// against the implementations)
GetPriorityScoreDistribution(ctx context.Context) (map[string]interface{}, error)
GetHighPriorityQuestions(ctx context.Context, limit int) ([]map[string]interface{}, error)
GetWeakAreasByTopic(ctx context.Context, limit int) ([]map[string]interface{}, error)
GetLearningPreferencesUsage(ctx context.Context) (map[string]interface{}, error)
GetQuestionTypeGaps(ctx context.Context) ([]map[string]interface{}, error)
GetGenerationSuggestions(ctx context.Context) ([]map[string]interface{}, error)
GetPrioritySystemPerformance(ctx context.Context) (map[string]interface{}, error)
GetBackgroundJobsStatus(ctx context.Context) (map[string]interface{}, error)
// User-specific analytics methods
GetUserPriorityScoreDistribution(ctx context.Context, userID int) (map[string]interface{}, error)
GetUserHighPriorityQuestions(ctx context.Context, userID, limit int) ([]map[string]interface{}, error)
GetUserWeakAreas(ctx context.Context, userID, limit int) ([]map[string]interface{}, error)
// Additional analytics methods for progress API
GetHighPriorityTopics(ctx context.Context, userID int) ([]string, error)
GetGapAnalysis(ctx context.Context, userID int) (map[string]interface{}, error)
GetPriorityDistribution(ctx context.Context, userID int) (map[string]int, error)
}
// UserQuestionStats represents per-user question statistics. It is the result
// type of LearningServiceInterface.GetUserQuestionStats and is serialized to
// JSON using the snake_case keys given in the field tags.
//
// NOTE(review): the per-field notes below are inferred from the field names;
// confirm them against the code that populates this struct.
type UserQuestionStats struct {
UserID int `json:"user_id"`
// Overall answer counts and accuracy (unit of AccuracyRate not visible here).
TotalAnswered int `json:"total_answered"`
CorrectAnswers int `json:"correct_answers"`
IncorrectAnswers int `json:"incorrect_answers"`
AccuracyRate float64 `json:"accuracy_rate"`
// Breakdowns; presumably keyed by question type and by level respectively.
AnsweredByType map[string]int `json:"answered_by_type"`
AnsweredByLevel map[string]int `json:"answered_by_level"`
AccuracyByType map[string]float64 `json:"accuracy_by_type"`
AccuracyByLevel map[string]float64 `json:"accuracy_by_level"`
AvailableByType map[string]int `json:"available_by_type"`
AvailableByLevel map[string]int `json:"available_by_level"`
RecentlyAnswered int `json:"recently_answered"` // Within last hour
}
// ErrQuestionNotFound, the error returned when a question does not exist in the
// database, is defined in the contextutils package (contextutils.ErrQuestionNotFound).
// LearningService provides methods for managing user learning progress. It
// records responses, maintains per-topic performance metrics and reads them
// back as progress reports.
type LearningService struct {
// db is the database handle used for all queries.
db *sql.DB
// cfg is the application configuration supplied at construction time.
cfg *config.Config
// logger is used for warnings such as failures to close result sets.
logger *observability.Logger
}
// NewLearningServiceWithLogger builds a LearningService from a database handle,
// the application configuration and a logger. The arguments are stored as
// given; none of them is validated here.
func NewLearningServiceWithLogger(db *sql.DB, cfg *config.Config, logger *observability.Logger) *LearningService {
	svc := new(LearningService)
	svc.db = db
	svc.cfg = cfg
	svc.logger = logger
	return svc
}
// RecordUserResponse inserts response into user_responses, stores the generated
// row ID in response.ID and then updates the performance metrics for the
// answered question's topic.
//
// An insert failure is returned unchanged. A failure while updating the metrics
// is returned as well, although the response row has already been written.
func (s *LearningService) RecordUserResponse(ctx context.Context, response *models.UserResponse) (err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "record_user_response",
		observability.AttributeUserID(response.UserID),
		observability.AttributeQuestionID(response.QuestionID),
		attribute.Bool("response.is_correct", response.IsCorrect),
		attribute.Int("response.time_ms", response.ResponseTimeMs),
	)
	defer observability.FinishSpan(span, &err)

	const insertSQL = `
INSERT INTO user_responses (user_id, question_id, user_answer_index, is_correct, response_time_ms)
VALUES ($1, $2, $3, $4, $5) RETURNING id
`
	var newID int
	row := s.db.QueryRowContext(ctx, insertSQL,
		response.UserID,
		response.QuestionID,
		response.UserAnswerIndex,
		response.IsCorrect,
		response.ResponseTimeMs,
	)
	if err = row.Scan(&newID); err != nil {
		return err
	}
	response.ID = newID

	// Update performance metrics
	return s.updatePerformanceMetrics(ctx, response)
}
// updatePerformanceMetrics folds response into the performance_metrics row for
// the (user, topic, language, level) of the answered question.
//
// The first response for a key creates the row with one attempt. Later
// responses increment total_attempts, add to correct_attempts when the answer
// was correct, and update the running mean of the response time. It returns the
// error from loading the question or from executing the upsert, unchanged.
func (s *LearningService) updatePerformanceMetrics(ctx context.Context, response *models.UserResponse) (err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "update_performance_metrics",
		observability.AttributeUserID(response.UserID),
		observability.AttributeQuestionID(response.QuestionID),
		attribute.Bool("response.is_correct", response.IsCorrect),
	)
	defer observability.FinishSpan(span, &err)

	// Get question details
	var question *models.Question
	question, err = s.getQuestionDetails(ctx, response.QuestionID)
	if err != nil {
		return err
	}

	// Update or create performance metrics.
	//
	// In the DO UPDATE branch every performance_metrics.<column> reference
	// yields the value stored BEFORE this update. With n = old total_attempts
	// and a = old average, the new mean over n+1 samples is (a*n + x)/(n+1).
	query := `
INSERT INTO performance_metrics (
user_id, topic, language, level, total_attempts, correct_attempts,
average_response_time_ms, difficulty_adjustment, last_updated
)
VALUES ($1, $2, $3, $4, 1, $5, $6, 0.0, CURRENT_TIMESTAMP)
ON CONFLICT(user_id, topic, language, level) DO UPDATE SET
total_attempts = performance_metrics.total_attempts + 1,
correct_attempts = performance_metrics.correct_attempts + $7,
average_response_time_ms = (performance_metrics.average_response_time_ms * performance_metrics.total_attempts + $8) / (performance_metrics.total_attempts + 1),
last_updated = CURRENT_TIMESTAMP
`
	correctIncrement := 0
	if response.IsCorrect {
		correctIncrement = 1
	}

	_, err = s.db.ExecContext(ctx, query,
		response.UserID,
		question.TopicCategory,
		question.Language,
		question.Level,
		correctIncrement,                 // For initial correct_attempts in VALUES
		float64(response.ResponseTimeMs), // For initial average_response_time_ms in VALUES
		correctIncrement,                 // For correct_attempts increment in UPDATE
		response.ResponseTimeMs,          // For average_response_time_ms calculation in UPDATE
	)
	return err
}
// getUserByID fetches the users row with the given ID for LearningService.
//
// It returns (nil, nil) when the user does not exist and passes any other
// database error through unchanged.
func (s *LearningService) getUserByID(ctx context.Context, userID int) (*models.User, error) {
	const selectUser = `
SELECT id, username, email, timezone, password_hash, last_active,
preferred_language, current_level, ai_provider, ai_model,
ai_enabled, ai_api_key, created_at, updated_at
FROM users
WHERE id = $1
`
	found := new(models.User)
	row := s.db.QueryRowContext(ctx, selectUser, userID)
	scanErr := row.Scan(
		&found.ID, &found.Username, &found.Email, &found.Timezone, &found.PasswordHash, &found.LastActive,
		&found.PreferredLanguage, &found.CurrentLevel, &found.AIProvider, &found.AIModel,
		&found.AIEnabled, &found.AIAPIKey, &found.CreatedAt, &found.UpdatedAt,
	)
	switch {
	case scanErr == nil:
		return found, nil
	case scanErr == sql.ErrNoRows:
		// A missing user is reported as a nil result, not as an error.
		return nil, nil
	default:
		return nil, scanErr
	}
}
// getQuestionDetails loads the type, language, level and topic category of the
// question with the given ID.
//
// A NULL topic_category leaves TopicCategory empty. The returned question is
// never nil; when the lookup fails it is returned together with the scan error,
// so callers must check err before relying on its fields.
func (s *LearningService) getQuestionDetails(ctx context.Context, questionID int) (result0 *models.Question, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_question_details",
		observability.AttributeQuestionID(questionID),
	)
	defer observability.FinishSpan(span, &err)

	var (
		details models.Question
		topic   sql.NullString
	)
	row := s.db.QueryRowContext(ctx, `SELECT type, language, level, topic_category FROM questions WHERE id = $1`, questionID)
	err = row.Scan(&details.Type, &details.Language, &details.Level, &topic)
	if topic.Valid {
		details.TopicCategory = topic.String
	}
	return &details, err
}
// GetUserProgress retrieves comprehensive learning progress for a user.
//
// The result combines the overall answer counts and accuracy (as a percentage),
// the per-topic performance metrics keyed by "topic_language_level", the weak
// areas derived from those metrics, the ten most recent responses, the user's
// current level and a suggested level. Any database error, including one raised
// while iterating the metrics rows, is returned unchanged with a nil result.
func (s *LearningService) GetUserProgress(ctx context.Context, userID int) (result0 *models.UserProgress, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_progress",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	progress := &models.UserProgress{
		PerformanceByTopic: make(map[string]*models.PerformanceMetrics),
	}

	// Get overall stats
	overallQuery := `
SELECT
COUNT(*) as total,
COALESCE(SUM(CASE WHEN is_correct THEN 1 ELSE 0 END), 0) as correct
FROM user_responses
WHERE user_id = $1
`
	err = s.db.QueryRowContext(ctx, overallQuery, userID).Scan(
		&progress.TotalQuestions,
		&progress.CorrectAnswers,
	)
	if err != nil && err != sql.ErrNoRows {
		return nil, err
	}
	if progress.TotalQuestions > 0 {
		progress.AccuracyRate = float64(progress.CorrectAnswers) / float64(progress.TotalQuestions) * 100
	}

	// Get performance by topic
	metricsQuery := `
SELECT id, topic, language, level, total_attempts, correct_attempts,
average_response_time_ms, difficulty_adjustment, last_updated
FROM performance_metrics
WHERE user_id = $1
`
	rows, err := s.db.QueryContext(ctx, metricsQuery, userID)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	for rows.Next() {
		metric := &models.PerformanceMetrics{UserID: userID}
		err = rows.Scan(
			&metric.ID,
			&metric.Topic,
			&metric.Language,
			&metric.Level,
			&metric.TotalAttempts,
			&metric.CorrectAttempts,
			&metric.AverageResponseTimeMs,
			&metric.DifficultyAdjustment,
			&metric.LastUpdated,
		)
		if err != nil {
			return nil, err
		}
		key := metric.Topic + "_" + metric.Language + "_" + metric.Level
		progress.PerformanceByTopic[key] = metric
	}
	// rows.Next also stops on failure; without this check a broken result set
	// would be reported as an incomplete but "successful" progress report.
	if err = rows.Err(); err != nil {
		return nil, err
	}

	// Identify weak areas (accuracy < 60%)
	progress.WeakAreas = s.identifyWeakAreas(progress.PerformanceByTopic)

	// Get recent activity
	progress.RecentActivity, err = s.getRecentActivity(ctx, userID, 10)
	if err != nil {
		return nil, err
	}

	// Get current level from user
	currentLevel, err := s.getCurrentUserLevel(ctx, userID)
	if err != nil {
		return nil, err
	}
	progress.CurrentLevel = currentLevel

	// Suggest level adjustment if needed
	progress.SuggestedLevel = s.suggestLevelAdjustment(progress)
	return progress, nil
}
// identifyWeakAreas returns the keys of every metric that has at least three
// attempts and an accuracy rate below 60%.
//
// It performs no external calls, so it is not traced. The order of the
// returned keys follows map iteration order and is therefore unspecified.
func (s *LearningService) identifyWeakAreas(metrics map[string]*models.PerformanceMetrics) []string {
	var weak []string
	for name, m := range metrics {
		if m.TotalAttempts <= 0 {
			continue
		}
		if rate := m.AccuracyRate(); rate < 60.0 && m.TotalAttempts >= 3 {
			weak = append(weak, name)
		}
	}
	return weak
}
// getRecentActivity returns up to limit of the user's responses, newest first
// (ordered by created_at descending).
//
// A nil slice is returned when the user has no responses. Query, scan and
// row-iteration errors are all returned to the caller.
func (s *LearningService) getRecentActivity(ctx context.Context, userID, limit int) (result0 []models.UserResponse, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_recent_activity",
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
	)
	defer observability.FinishSpan(span, &err)

	query := `
SELECT id, user_id, question_id, user_answer_index, is_correct, response_time_ms, created_at
FROM user_responses
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var responses []models.UserResponse
	for rows.Next() {
		var response models.UserResponse
		err = rows.Scan(
			&response.ID,
			&response.UserID,
			&response.QuestionID,
			&response.UserAnswerIndex,
			&response.IsCorrect,
			&response.ResponseTimeMs,
			&response.CreatedAt,
		)
		if err != nil {
			return nil, err
		}
		responses = append(responses, response)
	}
	// Surface errors that ended the iteration early instead of returning a
	// silently truncated list.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return responses, nil
}
// getCurrentUserLevel reads users.current_level for the given user.
//
// When the stored value is NULL or empty the default level "A1" is returned.
// Scan errors (including sql.ErrNoRows for an unknown user) are returned
// unchanged together with an empty level.
func (s *LearningService) getCurrentUserLevel(ctx context.Context, userID int) (result0 string, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_current_user_level",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	const defaultLevel = "A1"

	var stored sql.NullString
	row := s.db.QueryRowContext(ctx, `SELECT current_level FROM users WHERE id = $1`, userID)
	if err = row.Scan(&stored); err != nil {
		return "", err
	}
	if stored.Valid && stored.String != "" {
		return stored.String, nil
	}
	return defaultLevel, nil
}
// suggestLevelAdjustment proposes a level change based on overall accuracy.
//
// It returns "" when fewer than 20 questions have been answered or when the
// accuracy lies within [50, 85]. Above 85% it returns the result of
// getNextLevel, below 50% the result of getPreviousLevel. The function makes
// no external calls and is not traced.
func (s *LearningService) suggestLevelAdjustment(progress *models.UserProgress) string {
	switch {
	case progress.TotalQuestions < 20:
		// Not enough data to make a recommendation.
		return ""
	case progress.AccuracyRate > 85.0:
		return s.getNextLevel(progress.CurrentLevel)
	case progress.AccuracyRate < 50.0:
		return s.getPreviousLevel(progress.CurrentLevel)
	default:
		return ""
	}
}
// getNextLevel returns the level that follows currentLevel in the list from
// s.cfg.GetAllLevels(). If currentLevel is the last entry or is not found,
// currentLevel itself is returned.
func (s *LearningService) getNextLevel(currentLevel string) string {
	levels := s.cfg.GetAllLevels()
	// Stop one short of the end: the final level has no successor.
	for idx := 0; idx+1 < len(levels); idx++ {
		if levels[idx] == currentLevel {
			return levels[idx+1]
		}
	}
	return currentLevel
}
// getPreviousLevel returns the level that precedes currentLevel in the list
// from s.cfg.GetAllLevels(). If currentLevel is the first entry or is not
// found, currentLevel itself is returned.
func (s *LearningService) getPreviousLevel(currentLevel string) string {
	levels := s.cfg.GetAllLevels()
	// Start at index 1: the first level has no predecessor.
	for idx := 1; idx < len(levels); idx++ {
		if levels[idx] == currentLevel {
			return levels[idx-1]
		}
	}
	return currentLevel
}
// GetWeakestTopics returns the topics where the user performs poorest.
//
// Only performance_metrics rows with at least 3 attempts are considered. They
// are ordered by ascending accuracy (correct/total), ties broken by the oldest
// last_updated, and at most limit rows are returned. A nil slice is returned
// when nothing qualifies. Query, scan and row-iteration errors are returned.
func (s *LearningService) GetWeakestTopics(ctx context.Context, userID, limit int) (result0 []*models.PerformanceMetrics, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_weakest_topics",
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
	)
	defer observability.FinishSpan(span, &err)

	query := `
SELECT id, topic, language, level, total_attempts, correct_attempts, average_response_time_ms, difficulty_adjustment, last_updated
FROM performance_metrics
WHERE user_id = $1 AND total_attempts >= 3
ORDER BY (correct_attempts * 1.0 / total_attempts) ASC, last_updated ASC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var topics []*models.PerformanceMetrics
	for rows.Next() {
		metric := &models.PerformanceMetrics{UserID: userID}
		err = rows.Scan(
			&metric.ID,
			&metric.Topic,
			&metric.Language,
			&metric.Level,
			&metric.TotalAttempts,
			&metric.CorrectAttempts,
			&metric.AverageResponseTimeMs,
			&metric.DifficultyAdjustment,
			&metric.LastUpdated,
		)
		if err != nil {
			return nil, err
		}
		topics = append(topics, metric)
	}
	// Report iteration failures rather than returning a partial ranking.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return topics, nil
}
// ShouldAvoidQuestion determines if a question should be avoided for a user.
//
// A question is avoided when the user has at least one correct response to it
// inside the range returned by contextutils.UserLocalDayRange (requested with
// a 1-day window). The outcome is also recorded on the span as "should_avoid".
// If the count query fails, false is returned together with the error.
func (s *LearningService) ShouldAvoidQuestion(ctx context.Context, userID, questionID int) (result0 bool, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "should_avoid_question",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer observability.FinishSpan(span, &err)

	// Resolve the user's local 1-day window as UTC bounds.
	startUTC, endUTC, _, err := contextutils.UserLocalDayRange(ctx, userID, 1, s.getUserByID)
	if err != nil {
		return false, contextutils.WrapError(err, "failed to compute user local day range")
	}

	var correctInWindow int
	err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM user_responses
WHERE user_id = $1 AND question_id = $2 AND is_correct = true
AND created_at >= $3 AND created_at < $4
`, userID, questionID, startUTC, endUTC).Scan(&correctInWindow)

	avoid := correctInWindow > 0
	span.SetAttributes(attribute.Bool("should_avoid", avoid))
	return avoid, err
}
// GetUserQuestionStats returns comprehensive per-user question statistics.
//
// The result contains:
//   - answered counts and accuracy (in percent) per question type and level,
//   - counts of active, not-yet-answered questions assigned to the user in the
//     user's preferred language, per type and level,
//   - the number of responses given within the last hour,
//   - overall correct/incorrect totals and accuracy rate.
//
// Per-type and per-level accuracy are derived in memory from the grouped
// (type, level) rows instead of issuing one extra query per row: the grouped
// rows already carry the totals those queries would recompute.
//
// Failures of the user lookup, the answered query or the available query are
// returned. The "recently answered" and overall-total queries are best effort:
// on failure a warning is logged and the affected fields stay at zero.
func (s *LearningService) GetUserQuestionStats(ctx context.Context, userID int) (result0 *UserQuestionStats, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_question_stats",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	stats := &UserQuestionStats{
		UserID:           userID,
		AnsweredByType:   make(map[string]int),
		AnsweredByLevel:  make(map[string]int),
		AccuracyByType:   make(map[string]float64),
		AccuracyByLevel:  make(map[string]float64),
		AvailableByType:  make(map[string]int),
		AvailableByLevel: make(map[string]int),
	}

	// Get user's language and level preferences
	var userLanguage, userLevel string
	userQuery := `SELECT COALESCE(preferred_language, 'italian'), COALESCE(current_level, 'B1') FROM users WHERE id = $1`
	err = s.db.QueryRowContext(ctx, userQuery, userID).Scan(&userLanguage, &userLevel)
	if err != nil {
		return nil, err
	}
	span.SetAttributes(
		attribute.String("user.language", userLanguage),
		attribute.String("user.level", userLevel),
	)

	// Get questions answered by user with stats
	answeredQuery := `
SELECT
q.type,
q.level,
COUNT(*) as total,
SUM(CASE WHEN ur.is_correct THEN 1 ELSE 0 END) as correct
FROM user_responses ur
JOIN questions q ON ur.question_id = q.id
WHERE ur.user_id = $1
GROUP BY q.type, q.level
`
	answeredRows, err := s.db.QueryContext(ctx, answeredQuery, userID)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := answeredRows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// Correct-answer tallies, aggregated across the other grouping dimension.
	correctByType := make(map[string]int)
	correctByLevel := make(map[string]int)
	for answeredRows.Next() {
		var qType, level string
		var total, correct int
		if err = answeredRows.Scan(&qType, &level, &total, &correct); err != nil {
			return nil, err
		}
		stats.AnsweredByType[qType] += total
		stats.AnsweredByLevel[level] += total
		stats.TotalAnswered += total
		correctByType[qType] += correct
		correctByLevel[level] += correct
	}
	if err = answeredRows.Err(); err != nil {
		return nil, err
	}

	// Calculate accuracy rates from the aggregated tallies.
	for qType, answered := range stats.AnsweredByType {
		if answered > 0 {
			stats.AccuracyByType[qType] = float64(correctByType[qType]) / float64(answered) * 100
		}
	}
	for level, answered := range stats.AnsweredByLevel {
		if answered > 0 {
			stats.AccuracyByLevel[level] = float64(correctByLevel[level]) / float64(answered) * 100
		}
	}

	// Get available questions (not answered by user) that belong to this user
	availableQuery := `
SELECT
q.type,
q.level,
COUNT(*) as available
FROM questions q
JOIN user_questions uq ON uq.question_id = q.id
WHERE uq.user_id = $1
AND q.language = $2
AND q.status = 'active'
AND q.id NOT IN (
SELECT DISTINCT question_id
FROM user_responses
WHERE user_id = $3
)
GROUP BY q.type, q.level
`
	availableRows, err := s.db.QueryContext(ctx, availableQuery, userID, userLanguage, userID)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := availableRows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	for availableRows.Next() {
		var qType, level string
		var available int
		if err = availableRows.Scan(&qType, &level, &available); err != nil {
			return nil, err
		}
		stats.AvailableByType[qType] += available
		stats.AvailableByLevel[level] += available
	}
	if err = availableRows.Err(); err != nil {
		return nil, err
	}

	// Get recently answered questions (within last hour). Best effort: a
	// failure leaves the counter at 0.
	recentQuery := `
SELECT COUNT(*)
FROM user_responses ur
WHERE ur.user_id = $1
AND ur.created_at > NOW() - INTERVAL '1 hour'
`
	if recentErr := s.db.QueryRowContext(ctx, recentQuery, userID).Scan(&stats.RecentlyAnswered); recentErr != nil {
		s.logger.Warn(ctx, "Failed to count recently answered questions", map[string]interface{}{"error": recentErr.Error()})
		stats.RecentlyAnswered = 0
	}

	// Calculate overall correct/incorrect answers and accuracy rate. COALESCE
	// keeps the scan valid for users without responses, where SUM is NULL.
	overallQuery := `
SELECT
COUNT(*) as total,
COALESCE(SUM(CASE WHEN is_correct THEN 1 ELSE 0 END), 0) as correct
FROM user_responses
WHERE user_id = $1
`
	var total, correct int
	if overallErr := s.db.QueryRowContext(ctx, overallQuery, userID).Scan(&total, &correct); overallErr != nil {
		// Best effort: keep the zero defaults.
		s.logger.Warn(ctx, "Failed to load overall answer totals", map[string]interface{}{"error": overallErr.Error()})
		stats.CorrectAnswers = 0
		stats.IncorrectAnswers = 0
		stats.AccuracyRate = 0.0
		return stats, nil
	}
	stats.CorrectAnswers = correct
	stats.IncorrectAnswers = total - correct
	if total > 0 {
		stats.AccuracyRate = float64(correct) / float64(total) * 100
	} else {
		stats.AccuracyRate = 0.0
	}
	return stats, nil
}
// PRIORITY SYSTEM METHODS
// RecordAnswerWithPriority records a user's response and updates priority scores.
//
// It is the ID-discarding form of RecordAnswerWithPriorityReturningID: the
// response is stored via RecordUserResponse, a storage failure is returned
// wrapped with "failed to record user response", and on success the priority
// score is refreshed in a background goroutine.
func (s *LearningService) RecordAnswerWithPriority(ctx context.Context, userID, questionID, answerIndex int, isCorrect bool, responseTime int) error {
	_, err := s.RecordAnswerWithPriorityReturningID(ctx, userID, questionID, answerIndex, isCorrect, responseTime)
	return err
}
// RecordAnswerWithPriorityReturningID records a user's response, updates the
// priority score asynchronously, and returns the new user_responses ID.
//
// The response is timestamped with time.Now() and stored via
// RecordUserResponse; the returned ID is read back from the stored record. On
// a storage failure 0 and a wrapped error are returned and no background
// update is started.
//
// NOTE(review): the background goroutine receives the caller's ctx; if that
// context is cancelled when the caller returns, the asynchronous update may be
// aborted - confirm this is intended.
func (s *LearningService) RecordAnswerWithPriorityReturningID(ctx context.Context, userID, questionID, answerIndex int, isCorrect bool, responseTime int) (int, error) {
	record := models.UserResponse{
		UserID:          userID,
		QuestionID:      questionID,
		UserAnswerIndex: answerIndex,
		IsCorrect:       isCorrect,
		ResponseTimeMs:  responseTime,
		CreatedAt:       time.Now(),
	}

	// Insert the response; RecordUserResponse is expected to fill in record.ID.
	err := s.RecordUserResponse(ctx, &record)
	if err != nil {
		return 0, contextutils.WrapError(err, "failed to record user response")
	}

	// Refresh the priority score without blocking the caller.
	go s.updatePriorityScoreAsync(ctx, userID, questionID)

	return record.ID, nil
}
// MarkQuestionAsKnown marks a question as known for a user with an optional
// confidence level.
//
// The user_question_metadata row for (user, question) is inserted or updated
// with marked_as_known = TRUE and the given confidence level (which may be
// nil). A foreign key violation is translated to
// contextutils.ErrQuestionNotFound; any other database error is returned
// unchanged. On success the priority score is recalculated in a background
// goroutine.
func (s *LearningService) MarkQuestionAsKnown(ctx context.Context, userID, questionID int, confidenceLevel *int) (err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "mark_question_as_known",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer observability.FinishSpan(span, &err)

	// idFields builds a fresh field map for every log call so no map is shared
	// between log entries.
	idFields := func() map[string]interface{} {
		return map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
		}
	}

	s.logger.Debug(ctx, "MarkQuestionAsKnown called", idFields())

	// Upsert the metadata row together with the confidence level.
	_, err = s.db.ExecContext(ctx, `
INSERT INTO user_question_metadata (user_id, question_id, marked_as_known, marked_as_known_at, confidence_level, created_at, updated_at)
VALUES ($1, $2, TRUE, NOW(), $3, NOW(), NOW())
ON CONFLICT (user_id, question_id) DO UPDATE
SET marked_as_known = TRUE, marked_as_known_at = NOW(), confidence_level = $3, updated_at = NOW()
`, userID, questionID, confidenceLevel)

	if err == nil {
		s.logger.Debug(ctx, "MarkQuestionAsKnown succeeded", idFields())
		// Update priority score in background so the new confidence affects selection immediately
		go s.updatePriorityScoreAsync(ctx, userID, questionID)
		return nil
	}

	failure := idFields()
	failure["error"] = err.Error()
	failure["error_type"] = fmt.Sprintf("%T", err)
	s.logger.Debug(ctx, "MarkQuestionAsKnown error", failure)

	if isForeignKeyConstraintViolation(err) {
		s.logger.Debug(ctx, "Foreign key constraint violation detected", idFields())
		return contextutils.ErrQuestionNotFound
	}
	s.logger.Debug(ctx, "Not a foreign key constraint violation, returning original error", idFields())
	return err
}
// GetUserLearningPreferences retrieves user learning preferences.
//
// When no preferences row exists, the user's existence is checked first: an
// unknown user yields an error wrapping contextutils.ErrRecordNotFound, while
// an existing user gets default preferences created and returned. Other
// database errors are returned wrapped.
func (s *LearningService) GetUserLearningPreferences(ctx context.Context, userID int) (result0 *models.UserLearningPreferences, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_learning_preferences",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	var stored models.UserLearningPreferences
	row := s.db.QueryRowContext(ctx, `
SELECT id, user_id, focus_on_weak_areas, include_review_questions, fresh_question_ratio,
known_question_penalty, review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, last_daily_reminder_sent, daily_goal, created_at, updated_at
FROM user_learning_preferences
WHERE user_id = $1
`, userID)
	err = row.Scan(
		&stored.ID,
		&stored.UserID,
		&stored.FocusOnWeakAreas,
		&stored.IncludeReviewQuestions,
		&stored.FreshQuestionRatio,
		&stored.KnownQuestionPenalty,
		&stored.ReviewIntervalDays,
		&stored.WeakAreaBoost,
		&stored.DailyReminderEnabled,
		&stored.TTSVoice,
		&stored.LastDailyReminderSent,
		&stored.DailyGoal,
		&stored.CreatedAt,
		&stored.UpdatedAt,
	)

	switch {
	case err == sql.ErrNoRows:
		// Only create defaults for users that actually exist.
		var found bool
		err = s.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", userID).Scan(&found)
		if err != nil {
			return nil, contextutils.WrapError(err, "failed to check if user exists")
		}
		if !found {
			return nil, contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "user %d not found", userID)
		}
		return s.createDefaultPreferences(ctx, userID)
	case err != nil:
		return nil, contextutils.WrapError(err, "failed to get user preferences")
	}
	return &stored, nil
}
// UpdateLastDailyReminderSent sets last_daily_reminder_sent (and updated_at)
// to NOW() for the user's learning preferences.
//
// The statement is an upsert, so a preferences row is created when the user
// does not have one yet. Database errors are returned wrapped.
func (s *LearningService) UpdateLastDailyReminderSent(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "update_last_daily_reminder_sent",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	// INSERT ... ON CONFLICT creates the record if it doesn't exist.
	const upsertReminder = `
INSERT INTO user_learning_preferences (user_id, last_daily_reminder_sent, updated_at)
VALUES ($1, NOW(), NOW())
ON CONFLICT (user_id) DO UPDATE SET
last_daily_reminder_sent = NOW(),
updated_at = NOW()
`
	if _, err = s.db.ExecContext(ctx, upsertReminder, userID); err != nil {
		return contextutils.WrapError(err, "failed to update last daily reminder sent")
	}
	return nil
}
// UpdateUserLearningPreferences updates user learning preferences and returns
// the stored row.
//
// daily_goal is only overwritten when prefs.DailyGoal is non-zero (the SQL
// maps 0 to "keep the current value"). If the user has no preferences row yet,
// one is created from the provided values via createPreferencesWithValues.
// prefs must not be nil: its fields are read for the span attributes.
func (s *LearningService) UpdateUserLearningPreferences(ctx context.Context, userID int, prefs *models.UserLearningPreferences) (result0 *models.UserLearningPreferences, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "update_user_learning_preferences",
		observability.AttributeUserID(userID),
		attribute.Bool("prefs.focus_on_weak_areas", prefs.FocusOnWeakAreas),
		attribute.Bool("prefs.include_review_questions", prefs.IncludeReviewQuestions),
		attribute.Float64("prefs.fresh_question_ratio", prefs.FreshQuestionRatio),
		attribute.Float64("prefs.known_question_penalty", prefs.KnownQuestionPenalty),
		attribute.Int("prefs.review_interval_days", prefs.ReviewIntervalDays),
		attribute.Float64("prefs.weak_area_boost", prefs.WeakAreaBoost),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	var saved models.UserLearningPreferences
	row := s.db.QueryRowContext(ctx, `
UPDATE user_learning_preferences
SET focus_on_weak_areas = $2, include_review_questions = $3, fresh_question_ratio = $4,
known_question_penalty = $5, review_interval_days = $6, weak_area_boost = $7,
daily_reminder_enabled = $8, tts_voice = $9, daily_goal = COALESCE(NULLIF($10, 0), daily_goal), updated_at = NOW()
WHERE user_id = $1
RETURNING id, user_id, focus_on_weak_areas, include_review_questions, fresh_question_ratio,
known_question_penalty, review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, last_daily_reminder_sent, daily_goal, created_at, updated_at
`,
		userID,
		prefs.FocusOnWeakAreas,
		prefs.IncludeReviewQuestions,
		prefs.FreshQuestionRatio,
		prefs.KnownQuestionPenalty,
		prefs.ReviewIntervalDays,
		prefs.WeakAreaBoost,
		prefs.DailyReminderEnabled,
		prefs.TTSVoice,
		prefs.DailyGoal,
	)
	err = row.Scan(
		&saved.ID,
		&saved.UserID,
		&saved.FocusOnWeakAreas,
		&saved.IncludeReviewQuestions,
		&saved.FreshQuestionRatio,
		&saved.KnownQuestionPenalty,
		&saved.ReviewIntervalDays,
		&saved.WeakAreaBoost,
		&saved.DailyReminderEnabled,
		&saved.TTSVoice,
		&saved.LastDailyReminderSent,
		&saved.DailyGoal,
		&saved.CreatedAt,
		&saved.UpdatedAt,
	)

	switch {
	case err == sql.ErrNoRows:
		// No row was updated: create the preferences with the provided values.
		return s.createPreferencesWithValues(ctx, userID, prefs)
	case err != nil:
		return nil, contextutils.WrapError(err, "failed to update user preferences")
	}
	return &saved, nil
}
// createPreferencesWithValues creates learning preferences for a user with the
// provided values and returns the row that ends up stored.
//
// Numeric fields left at zero (fresh question ratio, known question penalty,
// review interval, weak area boost, daily goal) are replaced by the values
// from GetDefaultLearningPreferences. Boolean fields are used exactly as
// provided, because false cannot be told apart from "not set".
//
// prefs is modified in place: it receives the merged defaults and is then
// overwritten with the stored row, which may have been written by a concurrent
// request (the insert uses ON CONFLICT DO NOTHING).
func (s *LearningService) createPreferencesWithValues(ctx context.Context, userID int, prefs *models.UserLearningPreferences) (result0 *models.UserLearningPreferences, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "create_preferences_with_values",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	fallback := s.GetDefaultLearningPreferences()
	prefs.UserID = userID

	// Fill unset (zero) numeric fields from the defaults.
	if prefs.FreshQuestionRatio == 0 {
		prefs.FreshQuestionRatio = fallback.FreshQuestionRatio
	}
	if prefs.KnownQuestionPenalty == 0 {
		prefs.KnownQuestionPenalty = fallback.KnownQuestionPenalty
	}
	if prefs.ReviewIntervalDays == 0 {
		prefs.ReviewIntervalDays = fallback.ReviewIntervalDays
	}
	if prefs.WeakAreaBoost == 0 {
		prefs.WeakAreaBoost = fallback.WeakAreaBoost
	}
	if prefs.DailyGoal == 0 {
		prefs.DailyGoal = fallback.DailyGoal
	}

	// ON CONFLICT DO NOTHING tolerates a concurrent insert for the same user.
	_, err = s.db.ExecContext(ctx, `
INSERT INTO user_learning_preferences (user_id, focus_on_weak_areas, include_review_questions,
fresh_question_ratio, known_question_penalty,
review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, daily_goal, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW(), NOW())
ON CONFLICT (user_id) DO NOTHING
`,
		userID,
		prefs.FocusOnWeakAreas,
		prefs.IncludeReviewQuestions,
		prefs.FreshQuestionRatio,
		prefs.KnownQuestionPenalty,
		prefs.ReviewIntervalDays,
		prefs.WeakAreaBoost,
		prefs.DailyReminderEnabled,
		prefs.TTSVoice,
		prefs.DailyGoal,
	)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to create preferences with values")
	}

	// Read back whichever row is stored now: ours or a concurrent writer's.
	row := s.db.QueryRowContext(ctx, `
SELECT id, user_id, focus_on_weak_areas, include_review_questions, fresh_question_ratio,
known_question_penalty, review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, last_daily_reminder_sent, daily_goal, created_at, updated_at
FROM user_learning_preferences
WHERE user_id = $1
`, userID)
	err = row.Scan(
		&prefs.ID,
		&prefs.UserID,
		&prefs.FocusOnWeakAreas,
		&prefs.IncludeReviewQuestions,
		&prefs.FreshQuestionRatio,
		&prefs.KnownQuestionPenalty,
		&prefs.ReviewIntervalDays,
		&prefs.WeakAreaBoost,
		&prefs.DailyReminderEnabled,
		&prefs.TTSVoice,
		&prefs.LastDailyReminderSent,
		&prefs.DailyGoal,
		&prefs.CreatedAt,
		&prefs.UpdatedAt,
	)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to fetch created preferences")
	}
	return prefs, nil
}
// createDefaultPreferences creates default learning preferences for a user and
// returns the row that ends up stored.
//
// The insert uses ON CONFLICT DO NOTHING, so when another request created the
// row first, that row is the one read back and returned.
func (s *LearningService) createDefaultPreferences(ctx context.Context, userID int) (result0 *models.UserLearningPreferences, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "create_default_preferences",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	created := s.GetDefaultLearningPreferences()
	created.UserID = userID

	// ON CONFLICT DO NOTHING tolerates a concurrent insert for the same user.
	_, err = s.db.ExecContext(ctx, `
INSERT INTO user_learning_preferences (user_id, focus_on_weak_areas, include_review_questions,
fresh_question_ratio, known_question_penalty,
review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, daily_goal, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW(), NOW())
ON CONFLICT (user_id) DO NOTHING
`,
		userID,
		created.FocusOnWeakAreas,
		created.IncludeReviewQuestions,
		created.FreshQuestionRatio,
		created.KnownQuestionPenalty,
		created.ReviewIntervalDays,
		created.WeakAreaBoost,
		created.DailyReminderEnabled,
		created.TTSVoice,
		created.DailyGoal,
	)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to create default preferences")
	}

	// Read back whichever row is stored now: ours or a concurrent writer's.
	row := s.db.QueryRowContext(ctx, `
SELECT id, user_id, focus_on_weak_areas, include_review_questions, fresh_question_ratio,
known_question_penalty, review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, last_daily_reminder_sent, daily_goal, created_at, updated_at
FROM user_learning_preferences
WHERE user_id = $1
`, userID)
	err = row.Scan(
		&created.ID,
		&created.UserID,
		&created.FocusOnWeakAreas,
		&created.IncludeReviewQuestions,
		&created.FreshQuestionRatio,
		&created.KnownQuestionPenalty,
		&created.ReviewIntervalDays,
		&created.WeakAreaBoost,
		&created.DailyReminderEnabled,
		&created.TTSVoice,
		&created.LastDailyReminderSent,
		&created.DailyGoal,
		&created.CreatedAt,
		&created.UpdatedAt,
	)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to fetch created preferences")
	}
	return created, nil
}
// GetDefaultLearningPreferences returns default learning preferences.
//
// Every call returns a newly allocated value, so callers may modify the
// result freely.
func (s *LearningService) GetDefaultLearningPreferences() *models.UserLearningPreferences {
	defaults := new(models.UserLearningPreferences)

	// Question selection behaviour.
	defaults.FocusOnWeakAreas = true
	defaults.IncludeReviewQuestions = true
	defaults.FreshQuestionRatio = 0.3
	defaults.KnownQuestionPenalty = 0.1
	defaults.ReviewIntervalDays = 7
	defaults.WeakAreaBoost = 2.0

	// Reminders are off by default; goal and voice use baseline values.
	defaults.DailyReminderEnabled = false
	defaults.DailyGoal = 10
	defaults.TTSVoice = ""

	return defaults
}
// CalculatePriorityScore calculates the priority score of a specific question
// for a user.
//
// The score starts at a base of 100 and is multiplied by four factors:
// performance, spaced repetition, user preference ("marked as known") and
// freshness. The result is clamped to the range [1, 1000]. Failures to load
// the user's preferences or the question performance are returned wrapped
// with contextutils.ErrDatabaseQuery and a score of 0.
func (s *LearningService) CalculatePriorityScore(ctx context.Context, userID, questionID int) (result0 float64, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "calculate_priority_score",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const (
		baseScore = 100.0
		minScore  = 1.0
		maxScore  = 1000.0
	)

	// Inputs: the user's preferences and their history with this question.
	prefs, err := s.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		return 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get user preferences: %w", err)
	}
	history, err := s.getQuestionPerformance(ctx, userID, questionID)
	if err != nil {
		return 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get question performance: %w", err)
	}

	// Individual multipliers.
	byPerformance := s.calculatePerformanceMultiplier(history, prefs.WeakAreaBoost)
	byRepetition := s.calculateSpacedRepetitionBoost(history.LastSeenAt)
	byPreference := s.calculateUserPreferenceMultiplier(history, prefs)
	byFreshness := s.calculateFreshnessBoost(history.TimesAnswered)

	score := baseScore * byPerformance * byRepetition * byPreference * byFreshness

	// Clamp to prevent extreme values.
	switch {
	case score < minScore:
		return minScore, nil
	case score > maxScore:
		return maxScore, nil
	}
	return score, nil
}
// updatePriorityScoreAsync recalculates the priority score of a question for a
// user and upserts it into question_priority_scores.
//
// It returns nothing: calculation and persistence failures are logged as
// errors. Callers in this file run it in its own goroutine.
func (s *LearningService) updatePriorityScoreAsync(ctx context.Context, userID, questionID int) {
	ctx, span := observability.TraceLearningFunction(ctx, "update_priority_score_async",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer span.End()

	score, calcErr := s.CalculatePriorityScore(ctx, userID, questionID)
	if calcErr != nil {
		s.logger.Error(ctx, "Failed to calculate priority score", calcErr, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
		})
		return
	}

	// Insert the score, or overwrite the existing one for this (user, question).
	const upsertScore = `
INSERT INTO question_priority_scores (user_id, question_id, priority_score, last_calculated_at, created_at, updated_at)
VALUES ($1, $2, $3, NOW(), NOW(), NOW())
ON CONFLICT (user_id, question_id) DO UPDATE
SET priority_score = $3, last_calculated_at = NOW(), updated_at = NOW()
`
	if _, saveErr := s.db.ExecContext(ctx, upsertScore, userID, questionID, score); saveErr != nil {
		s.logger.Error(ctx, "Failed to update priority score", saveErr, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
			"score":       score,
		})
	}
}
// QuestionPerformance represents performance data for a specific question,
// as collected for one user by getQuestionPerformance.
type QuestionPerformance struct {
	// TimesAnswered is the number of recorded responses; CorrectAnswers is
	// how many of them were correct.
	TimesAnswered, CorrectAnswers int
	// LastSeenAt is the time of the most recent response, nil if there is none.
	LastSeenAt *time.Time
	// MarkedAsKnown reports whether the user marked the question as known.
	MarkedAsKnown bool
	// MarkedAsKnownAt is when the question was marked as known, nil if unset.
	MarkedAsKnownAt *time.Time
	// ConfidenceLevel is the confidence supplied when marking, nil if unset.
	ConfidenceLevel *int
}
// getQuestionPerformance retrieves performance data for a specific question.
//
// Response statistics (count, correct count, latest timestamp) come from
// user_responses; the "marked as known" flag, its timestamp and the confidence
// level come from user_question_metadata. A missing row in either lookup is
// not an error and leaves the corresponding fields at their zero values. Other
// failures are returned wrapped with contextutils.ErrDatabaseQuery.
func (s *LearningService) getQuestionPerformance(ctx context.Context, userID, questionID int) (result0 *QuestionPerformance, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_question_performance",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	perf := new(QuestionPerformance)

	// Response statistics.
	statsRow := s.db.QueryRowContext(ctx, `
SELECT
COUNT(*) as times_answered,
COALESCE(SUM(CASE WHEN is_correct THEN 1 ELSE 0 END), 0) as correct_answers,
MAX(created_at) as last_seen_at
FROM user_responses
WHERE user_id = $1 AND question_id = $2
`, userID, questionID)
	err = statsRow.Scan(&perf.TimesAnswered, &perf.CorrectAnswers, &perf.LastSeenAt)
	if err != nil && err != sql.ErrNoRows {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get response statistics: %w", err)
	}

	// Metadata; the nullable columns are scanned through sql.Null* wrappers.
	var (
		knownAt    sql.NullTime
		confidence sql.NullInt32
	)
	metaRow := s.db.QueryRowContext(ctx, `
SELECT marked_as_known, marked_as_known_at, confidence_level
FROM user_question_metadata
WHERE user_id = $1 AND question_id = $2
`, userID, questionID)
	err = metaRow.Scan(&perf.MarkedAsKnown, &knownAt, &confidence)
	if err != nil && err != sql.ErrNoRows {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get question metadata: %w", err)
	}

	if knownAt.Valid {
		markedAt := knownAt.Time
		perf.MarkedAsKnownAt = &markedAt
	}
	if confidence.Valid {
		value := int(confidence.Int32)
		perf.ConfidenceLevel = &value
	}
	return perf, nil
}
// calculatePerformanceMultiplier calculates the performance-based multiplier.
//
// Questions that were never answered get a neutral 1.0. Otherwise the
// multiplier is 1 + errorRate*weakAreaBoost - successRate*0.5, clamped to
// [0.1, 10]. The function makes no external calls and is not traced.
func (s *LearningService) calculatePerformanceMultiplier(performance *QuestionPerformance, weakAreaBoost float64) float64 {
	const (
		lowerBound = 0.1
		upperBound = 10.0
	)

	answered := performance.TimesAnswered
	if answered == 0 {
		return 1.0 // Neutral for new questions
	}

	wrong := answered - performance.CorrectAnswers
	errorRate := float64(wrong) / float64(answered)
	successRate := float64(performance.CorrectAnswers) / float64(answered)

	// High error rates raise the multiplier, high success rates lower it.
	multiplier := 1.0 + (errorRate * weakAreaBoost) - (successRate * 0.5)

	switch {
	case multiplier < lowerBound:
		return lowerBound
	case multiplier > upperBound:
		return upperBound
	}
	return multiplier
}
// calculateSpacedRepetitionBoost calculates the spaced repetition boost.
//
// A nil lastSeenAt (question never seen) yields 1.0. Otherwise the boost grows
// by 0.1 per day elapsed since lastSeenAt and is capped at 5.0. The function
// makes no external calls and is not traced.
func (s *LearningService) calculateSpacedRepetitionBoost(lastSeenAt *time.Time) float64 {
	const maxBoost = 5.0

	if lastSeenAt == nil {
		return 1.0 // No boost for never-seen questions
	}

	elapsedDays := time.Since(*lastSeenAt).Hours() / 24.0
	uncapped := 1.0 + (elapsedDays * 0.1)
	return math.Min(uncapped, maxBoost)
}
// calculateUserPreferenceMultiplier translates a user's "mark as known" choice
// and the accompanying confidence level into a priority multiplier.
//
// Policy:
//   - not marked as known: 1.0 (no effect)
//   - marked, no confidence recorded: prefs.KnownQuestionPenalty
//   - confidence 1: 1.25 (low confidence, show noticeably more often)
//   - confidence 2: 1.10 (some confidence, show slightly more often)
//   - confidence 3: 1.0 (neutral)
//   - confidence 4: prefs.KnownQuestionPenalty * 0.5 (intended to show less often)
//   - confidence 5: prefs.KnownQuestionPenalty * 0.1 (intended to show far less often)
//   - any other confidence value: 1.0
//
// prefs is only dereferenced on the branches that use KnownQuestionPenalty.
func (s *LearningService) calculateUserPreferenceMultiplier(performance *QuestionPerformance, prefs *models.UserLearningPreferences) float64 {
	// Pure computation: nothing external is called, so no tracing span is opened.
	if !performance.MarkedAsKnown {
		return 1.0
	}

	confidence := performance.ConfidenceLevel
	if confidence == nil {
		// Marked as known without a confidence rating: apply the configured penalty as is.
		return prefs.KnownQuestionPenalty
	}

	switch *confidence {
	case 1:
		return 1.25
	case 2:
		return 1.10
	case 4:
		return prefs.KnownQuestionPenalty * 0.5
	case 5:
		return prefs.KnownQuestionPenalty * 0.1
	default:
		// Confidence 3 and any out-of-range value leave the priority untouched.
		return 1.0
	}
}
// calculateFreshnessBoost returns 1.5 for a question that has never been
// answered (timesAnswered == 0) and the neutral 1.0 for every other value.
func (s *LearningService) calculateFreshnessBoost(timesAnswered int) float64 {
	// Pure computation: nothing external is called, so no tracing span is opened.
	boost := 1.0
	if timesAnswered == 0 {
		boost = 1.5
	}
	return boost
}
// isForeignKeyConstraintViolation reports whether err describes a foreign key
// constraint violation.
//
// Two checks are applied: a *pq.Error carrying SQLSTATE 23503, and, as a
// fallback for any other error value, the presence of the text
// "violates foreign key constraint" in the error message. A nil error is never
// a violation.
func isForeignKeyConstraintViolation(err error) bool {
	if err == nil {
		return false
	}

	// Preferred signal: the driver error's SQLSTATE code.
	pqErr, isPQ := err.(*pq.Error)
	if isPQ && pqErr.Code == "23503" {
		return true
	}

	// Fallback: match on the message text.
	return strings.Contains(err.Error(), "violates foreign key constraint")
}
// Analytics Methods
// GetPriorityScoreDistribution reports how priority scores are spread across
// all stored question priority rows.
//
// Only rows with a positive score that join to an existing question are
// considered. The returned map holds:
//   - "high": number of scores above 200
//   - "medium": number of scores between 100 and 200 (inclusive)
//   - "low": number of scores below 100
//   - "average": mean score as float64, 0.0 when the average is NULL
//
// A query or scan failure is returned through contextutils.WrapErrorf with
// ErrDatabaseQuery and recorded on the tracing span.
func (s *LearningService) GetPriorityScoreDistribution(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_priority_score_distribution")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const distributionSQL = `
SELECT
COUNT(CASE WHEN qps.priority_score > 200 THEN 1 END) as high,
COUNT(CASE WHEN qps.priority_score BETWEEN 100 AND 200 THEN 1 END) as medium,
COUNT(CASE WHEN qps.priority_score < 100 THEN 1 END) as low,
AVG(qps.priority_score) as average
FROM question_priority_scores qps
JOIN questions q ON qps.question_id = q.id
WHERE qps.priority_score > 0
`

	var (
		highCount, mediumCount, lowCount int
		avg                              sql.NullFloat64
	)
	row := s.db.QueryRowContext(ctx, distributionSQL)
	if err = row.Scan(&highCount, &mediumCount, &lowCount, &avg); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get priority score distribution: %w", err)
	}

	// AVG over zero rows is NULL; report it as 0.0.
	mean := 0.0
	if avg.Valid {
		mean = avg.Float64
	}

	span.SetAttributes(
		attribute.Int("high_count", highCount),
		attribute.Int("medium_count", mediumCount),
		attribute.Int("low_count", lowCount),
		attribute.Float64("average_score", mean),
	)

	return map[string]interface{}{
		"high":    highCount,
		"medium":  mediumCount,
		"low":     lowCount,
		"average": mean,
	}, nil
}
// GetHighPriorityQuestions returns up to limit questions whose priority score
// exceeds 200, ordered from the highest score downwards.
//
// Each entry is a map with the keys "question_type", "level" and "topic"
// (strings, empty when the column is NULL) and "priority_score" (float64). The
// returned slice is nil when nothing matched.
//
// Rows that fail to scan are skipped and counted in the "scan_errors" span
// attribute. An error raised while iterating the result set (reported by
// rows.Err) is returned instead of a silently truncated list.
func (s *LearningService) GetHighPriorityQuestions(ctx context.Context, limit int) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_high_priority_questions",
		attribute.Int("limit", limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
q.type as question_type,
q.level,
q.topic_category as topic,
qps.priority_score
FROM question_priority_scores qps
JOIN questions q ON qps.question_id = q.id
WHERE qps.priority_score > 200
ORDER BY qps.priority_score DESC
LIMIT $1
`
	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get high priority questions: %w", err)
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	var questions []map[string]interface{}
	var scanErrors int
	for rows.Next() {
		var questionType, level, topic sql.NullString
		var priorityScore float64
		if scanErr := rows.Scan(&questionType, &level, &topic, &priorityScore); scanErr != nil {
			// Best effort: skip the undecodable row but keep the failure visible in the trace.
			scanErrors++
			continue
		}
		questions = append(questions, map[string]interface{}{
			"question_type":  questionType.String,
			"level":          level.String,
			"topic":          topic.String,
			"priority_score": priorityScore,
		})
	}
	// rows.Next also returns false when iteration aborts (e.g. connection loss or
	// context cancellation); surface that instead of returning a partial result.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error during rows iteration: %w", iterErr)
	}

	span.SetAttributes(
		attribute.Int("questions_count", len(questions)),
		attribute.Int("scan_errors", scanErrors),
	)
	return questions, nil
}
// GetWeakAreasByTopic returns up to limit topics with the lowest accuracy,
// aggregated over every performance_metrics row that has at least one attempt.
//
// Topics are ordered by SUM(correct_attempts)/SUM(total_attempts) ascending, so
// the weakest topic comes first. Each entry is a map with the keys "topic"
// (string, empty when NULL), "total_attempts" and "correct_attempts" (ints).
// The returned slice is nil when nothing matched.
//
// Rows that fail to scan are skipped and counted in the "scan_errors" span
// attribute. An error raised while iterating the result set (reported by
// rows.Err) is returned instead of a silently truncated list.
func (s *LearningService) GetWeakAreasByTopic(ctx context.Context, limit int) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_weak_areas_by_topic",
		attribute.Int("limit", limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
topic,
SUM(total_attempts) as total_attempts,
SUM(correct_attempts) as correct_attempts
FROM performance_metrics
WHERE total_attempts > 0
GROUP BY topic
ORDER BY (SUM(correct_attempts)::float / SUM(total_attempts)) ASC
LIMIT $1
`
	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get weak areas: %w", err)
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	var weakAreas []map[string]interface{}
	var scanErrors int
	for rows.Next() {
		var topic sql.NullString
		var totalAttempts, correctAttempts int
		if scanErr := rows.Scan(&topic, &totalAttempts, &correctAttempts); scanErr != nil {
			// Best effort: skip the undecodable row but keep the failure visible in the trace.
			scanErrors++
			continue
		}
		weakAreas = append(weakAreas, map[string]interface{}{
			"topic":            topic.String,
			"total_attempts":   totalAttempts,
			"correct_attempts": correctAttempts,
		})
	}
	// rows.Next also returns false when iteration aborts (e.g. connection loss or
	// context cancellation); surface that instead of returning a partial result.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error during rows iteration: %w", iterErr)
	}

	span.SetAttributes(
		attribute.Int("weak_areas_count", len(weakAreas)),
		attribute.Int("scan_errors", scanErrors),
	)
	return weakAreas, nil
}
// GetLearningPreferencesUsage summarizes the values stored in
// user_learning_preferences.
//
// The returned map holds:
//   - "total_users": number of preference rows (int)
//   - "focusOnWeakAreas": true when the average of focus_on_weak_areas exceeds 0.5
//   - "freshQuestionRatio", "weakAreaBoost", "knownQuestionPenalty": column averages
//
// When the table is empty, or an individual average is NULL, the fallbacks
// false, 0.3, 2.0 and 0.1 are reported for the respective keys.
func (s *LearningService) GetLearningPreferencesUsage(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_learning_preferences_usage")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const usageSQL = `
SELECT
COUNT(*) as total_users,
AVG(focus_on_weak_areas::int) as avg_focus_on_weak_areas,
AVG(fresh_question_ratio) as avg_fresh_question_ratio,
AVG(weak_area_boost) as avg_weak_area_boost,
AVG(known_question_penalty) as avg_known_question_penalty
FROM user_learning_preferences
`

	var (
		rowCount                                        int
		avgFocus, avgFreshRatio, avgBoost, avgPenalty sql.NullFloat64
	)
	row := s.db.QueryRowContext(ctx, usageSQL)
	if err = row.Scan(&rowCount, &avgFocus, &avgFreshRatio, &avgBoost, &avgPenalty); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get learning preferences usage: %w", err)
	}

	// Start from the fallbacks and overwrite only what the database provided.
	users := 0
	focusOnWeakAreas := false
	freshRatio, weakAreaBoost, knownPenalty := 0.3, 2.0, 0.1
	if rowCount > 0 {
		users = rowCount
		if avgFocus.Valid {
			// A majority of users having the flag enabled counts as "on".
			focusOnWeakAreas = avgFocus.Float64 > 0.5
		}
		if avgFreshRatio.Valid {
			freshRatio = avgFreshRatio.Float64
		}
		if avgBoost.Valid {
			weakAreaBoost = avgBoost.Float64
		}
		if avgPenalty.Valid {
			knownPenalty = avgPenalty.Float64
		}
	}

	span.SetAttributes(
		attribute.Int("total_users", users),
		attribute.Bool("focus_on_weak_areas", focusOnWeakAreas),
		attribute.Float64("fresh_question_ratio", freshRatio),
		attribute.Float64("weak_area_boost", weakAreaBoost),
		attribute.Float64("known_question_penalty", knownPenalty),
	)

	return map[string]interface{}{
		"total_users":          users,
		"focusOnWeakAreas":     focusOnWeakAreas,
		"freshQuestionRatio":   freshRatio,
		"weakAreaBoost":        weakAreaBoost,
		"knownQuestionPenalty": knownPenalty,
	}, nil
}
// GetQuestionTypeGaps lists the (question type, level) combinations for which
// fewer than 80% of the available questions have a priority score row, ordered
// from the lowest coverage upwards.
//
// Each entry is a map with the keys "question_type" and "level" (strings, empty
// when NULL), "available" (number of questions) and "demand" (questions without
// a priority score row). Rows that fail to scan are skipped, counted in the
// "scan_errors" span attribute and noted on the span; query and iteration
// errors are returned.
func (s *LearningService) GetQuestionTypeGaps(ctx context.Context) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_question_type_gaps")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const gapsSQL = `
SELECT
q.type as question_type,
q.level,
COUNT(q.id) as available,
COUNT(qps.question_id) as with_priority_scores
FROM questions q
LEFT JOIN question_priority_scores qps ON q.id = qps.question_id
GROUP BY q.type, q.level
HAVING COUNT(qps.question_id) < COUNT(q.id) * 0.8
ORDER BY (COUNT(qps.question_id)::float / COUNT(q.id)) ASC
`
	rows, err := s.db.QueryContext(ctx, gapsSQL)
	if err != nil {
		span.SetAttributes(attribute.String("error.type", "database_query_failed"), attribute.String("error", err.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get question type gaps: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows in GetQuestionTypeGaps", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var (
		gaps        []map[string]interface{}
		failedScans int
	)
	for rows.Next() {
		var (
			qType, qLevel sql.NullString
			total, scored int
		)
		if err = rows.Scan(&qType, &qLevel, &total, &scored); err != nil {
			// Best effort: note the failure on the span and move on to the next row.
			failedScans++
			span.SetAttributes(attribute.String("error.type", "row_scan_failed"), attribute.String("error", err.Error()))
			continue
		}
		gaps = append(gaps, map[string]interface{}{
			"question_type": qType.String,
			"level":         qLevel.String,
			"available":     total,
			"demand":        total - scored,
		})
	}
	if iterErr := rows.Err(); iterErr != nil {
		span.SetAttributes(attribute.String("error.type", "rows_iteration_failed"), attribute.String("error", iterErr.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error during rows iteration: %w", iterErr)
	}

	span.SetAttributes(
		attribute.Int("gaps_count", len(gaps)),
		attribute.Int("scan_errors", failedScans),
	)
	return gaps, nil
}
// GetGenerationSuggestions lists the (question type, level, language)
// combinations that have fewer than 50 questions or fewer than 10 questions
// with a priority score above 100, ordered by question count ascending and then
// by average priority descending.
//
// Each entry is a map with the keys "question_type", "level", "language"
// (strings, empty when NULL), "available", "high_priority" (ints) and
// "avg_priority" / "priority_score" (both the group's average score, 0.0 when
// NULL). Rows that fail to scan are skipped, counted in the "scan_errors" span
// attribute and noted on the span; query and iteration errors are returned.
func (s *LearningService) GetGenerationSuggestions(ctx context.Context) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_generation_suggestions")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const suggestionsSQL = `
SELECT
q.type as question_type,
q.level,
q.language,
COUNT(q.id) as available,
COUNT(CASE WHEN qps.priority_score > 100 THEN 1 END) as high_priority,
AVG(qps.priority_score) as avg_priority
FROM questions q
LEFT JOIN question_priority_scores qps ON q.id = qps.question_id
GROUP BY q.type, q.level, q.language
HAVING COUNT(q.id) < 50 OR COUNT(CASE WHEN qps.priority_score > 100 THEN 1 END) < 10
ORDER BY COUNT(q.id) ASC, AVG(qps.priority_score) DESC
`
	rows, err := s.db.QueryContext(ctx, suggestionsSQL)
	if err != nil {
		span.SetAttributes(attribute.String("error.type", "database_query_failed"), attribute.String("error", err.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get generation suggestions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows in GetGenerationSuggestions", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var (
		suggestions []map[string]interface{}
		failedScans int
	)
	for rows.Next() {
		var (
			qType, qLevel, qLanguage sql.NullString
			total, highCount         int
			avg                      sql.NullFloat64
		)
		if err = rows.Scan(&qType, &qLevel, &qLanguage, &total, &highCount, &avg); err != nil {
			// Best effort: note the failure on the span and move on to the next row.
			failedScans++
			span.SetAttributes(attribute.String("error.type", "row_scan_failed"), attribute.String("error", err.Error()))
			continue
		}

		// Groups without any priority rows have a NULL average; report 0.0.
		mean := 0.0
		if avg.Valid {
			mean = avg.Float64
		}
		suggestions = append(suggestions, map[string]interface{}{
			"question_type":  qType.String,
			"level":          qLevel.String,
			"language":       qLanguage.String,
			"available":      total,
			"high_priority":  highCount,
			"avg_priority":   mean,
			"priority_score": mean,
		})
	}
	if iterErr := rows.Err(); iterErr != nil {
		span.SetAttributes(attribute.String("error.type", "rows_iteration_failed"), attribute.String("error", iterErr.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error during rows iteration: %w", iterErr)
	}

	span.SetAttributes(
		attribute.Int("suggestions_count", len(suggestions)),
		attribute.Int("scan_errors", failedScans),
	)
	return suggestions, nil
}
// GetPrioritySystemPerformance returns coarse activity figures for the priority
// system, derived from question_priority_scores rows calculated within the last
// hour.
//
// The returned map holds:
//   - "calculationsPerSecond": rows calculated in the last hour divided by 3600
//   - "avgScore": mean priority score of those rows, 0.0 when NULL
//   - "lastCalculation": RFC 3339 timestamp of the newest calculation; the key
//     is absent when there is none
//   - "avgCalculationTime", "avgQueryTime", "memoryUsage": always 0.0, these
//     figures are not tracked by this simplified implementation
func (s *LearningService) GetPrioritySystemPerformance(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_priority_system_performance")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const performanceSQL = `
SELECT
COUNT(*) as total_calculations,
AVG(priority_score) as avg_score,
MAX(last_calculated_at) as last_calculation
FROM question_priority_scores
WHERE last_calculated_at > NOW() - INTERVAL '1 hour'
`

	var (
		calculations int
		avg          sql.NullFloat64
		newest       sql.NullTime
	)
	row := s.db.QueryRowContext(ctx, performanceSQL)
	if err = row.Scan(&calculations, &avg, &newest); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get priority system performance: %w", err)
	}

	// The count covers one hour, so dividing by 3600 yields a per-second rate.
	perSecond := float64(calculations) / 3600.0
	meanScore := 0.0
	if avg.Valid {
		meanScore = avg.Float64
	}

	result := map[string]interface{}{
		"calculationsPerSecond": perSecond,
		"avgCalculationTime":    0.0,
		"avgQueryTime":          0.0,
		"memoryUsage":           0.0,
		"avgScore":              meanScore,
	}
	if newest.Valid {
		result["lastCalculation"] = newest.Time.Format(time.RFC3339)
	}

	span.SetAttributes(
		attribute.Float64("calculations_per_second", perSecond),
		attribute.Float64("avg_score", meanScore),
		attribute.Int("total_calculations", calculations),
	)
	return result, nil
}
// GetBackgroundJobsStatus returns a coarse status of the priority update
// activity, derived from question_priority_scores rows updated within the last
// minute.
//
// The returned map holds:
//   - "priorityUpdates": number of rows updated in the last minute
//   - "lastUpdate": RFC 3339 timestamp of the newest update, or "N/A"
//   - "queueSize": always 0, queue size is not tracked by this simplified implementation
//   - "status": "idle" when no row was updated, otherwise "healthy"
func (s *LearningService) GetBackgroundJobsStatus(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_background_jobs_status")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const statusSQL = `
SELECT
COUNT(*) as total_updates,
MAX(updated_at) as last_update
FROM question_priority_scores
WHERE updated_at > NOW() - INTERVAL '1 minute'
`

	var (
		updates int
		newest  sql.NullTime
	)
	row := s.db.QueryRowContext(ctx, statusSQL)
	if err = row.Scan(&updates, &newest); err != nil {
		return nil, contextutils.WrapError(err, "failed to get background jobs status")
	}

	status := "healthy"
	if updates == 0 {
		status = "idle"
	}
	lastSeen := "N/A"
	if newest.Valid {
		lastSeen = newest.Time.Format(time.RFC3339)
	}
	queueSize := 0

	span.SetAttributes(
		attribute.Int("priority_updates", updates),
		attribute.String("status", status),
		attribute.Int("queue_size", queueSize),
	)

	return map[string]interface{}{
		"priorityUpdates": updates,
		"lastUpdate":      lastSeen,
		"queueSize":       queueSize,
		"status":          status,
	}, nil
}
// GetUserPriorityScoreDistribution reports how the positive priority scores of
// a single user are spread.
//
// The returned map holds:
//   - "high": number of scores above 200
//   - "medium": number of scores between 100 and 200 (inclusive)
//   - "low": number of scores below 100
//   - "average": mean score as float64, 0.0 when the average is NULL
func (s *LearningService) GetUserPriorityScoreDistribution(ctx context.Context, userID int) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_priority_score_distribution",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const distributionSQL = `
SELECT
COUNT(CASE WHEN priority_score > 200 THEN 1 END) as high,
COUNT(CASE WHEN priority_score BETWEEN 100 AND 200 THEN 1 END) as medium,
COUNT(CASE WHEN priority_score < 100 THEN 1 END) as low,
AVG(priority_score) as average
FROM question_priority_scores
WHERE user_id = $1 AND priority_score > 0
`

	var (
		highCount, mediumCount, lowCount int
		avg                              sql.NullFloat64
	)
	row := s.db.QueryRowContext(ctx, distributionSQL, userID)
	if err = row.Scan(&highCount, &mediumCount, &lowCount, &avg); err != nil {
		return nil, contextutils.WrapError(err, "failed to get user priority score distribution")
	}

	// AVG over zero rows is NULL; report it as 0.0.
	mean := 0.0
	if avg.Valid {
		mean = avg.Float64
	}

	span.SetAttributes(
		attribute.Int("high_count", highCount),
		attribute.Int("medium_count", mediumCount),
		attribute.Int("low_count", lowCount),
		attribute.Float64("average_score", mean),
	)

	return map[string]interface{}{
		"high":    highCount,
		"medium":  mediumCount,
		"low":     lowCount,
		"average": mean,
	}, nil
}
// GetUserHighPriorityQuestions returns up to limit questions of the given user
// whose priority score exceeds 200, ordered from the highest score downwards.
//
// Each entry is a map with the keys "question_type", "level" and "topic"
// (strings, empty when the column is NULL) and "priority_score" (float64). The
// returned slice is nil when nothing matched.
//
// Rows that fail to scan are skipped and counted in the "scan_errors" span
// attribute. An error raised while iterating the result set (reported by
// rows.Err) is returned instead of a silently truncated list.
func (s *LearningService) GetUserHighPriorityQuestions(ctx context.Context, userID, limit int) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_high_priority_questions",
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
q.type as question_type,
q.level,
q.topic_category as topic,
qps.priority_score
FROM question_priority_scores qps
JOIN questions q ON qps.question_id = q.id
WHERE qps.user_id = $1 AND qps.priority_score > 200
ORDER BY qps.priority_score DESC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get user high priority questions")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	var questions []map[string]interface{}
	var scanErrors int
	for rows.Next() {
		var questionType, level, topic sql.NullString
		var priorityScore float64
		if scanErr := rows.Scan(&questionType, &level, &topic, &priorityScore); scanErr != nil {
			// Best effort: skip the undecodable row but keep the failure visible in the trace.
			scanErrors++
			continue
		}
		questions = append(questions, map[string]interface{}{
			"question_type":  questionType.String,
			"level":          level.String,
			"topic":          topic.String,
			"priority_score": priorityScore,
		})
	}
	// rows.Next also returns false when iteration aborts (e.g. connection loss or
	// context cancellation); surface that instead of returning a partial result.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error during rows iteration")
	}

	span.SetAttributes(
		attribute.Int("questions_count", len(questions)),
		attribute.Int("scan_errors", scanErrors),
	)
	return questions, nil
}
// GetUserWeakAreas returns up to limit performance_metrics rows of the given
// user that have at least one attempt, ordered by accuracy
// (correct_attempts/total_attempts) ascending so the weakest comes first.
//
// Each entry is a map with the keys "topic" (string, empty when NULL),
// "total_attempts" and "correct_attempts" (ints). The returned slice is nil
// when nothing matched.
//
// Rows that fail to scan are skipped and counted in the "scan_errors" span
// attribute. An error raised while iterating the result set (reported by
// rows.Err) is returned instead of a silently truncated list.
func (s *LearningService) GetUserWeakAreas(ctx context.Context, userID, limit int) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_weak_areas",
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
topic,
total_attempts,
correct_attempts
FROM performance_metrics
WHERE user_id = $1 AND total_attempts > 0
ORDER BY (correct_attempts::float / total_attempts) ASC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get user weak areas")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	var weakAreas []map[string]interface{}
	var scanErrors int
	for rows.Next() {
		var topic sql.NullString
		var totalAttempts, correctAttempts int
		if scanErr := rows.Scan(&topic, &totalAttempts, &correctAttempts); scanErr != nil {
			// Best effort: skip the undecodable row but keep the failure visible in the trace.
			scanErrors++
			continue
		}
		weakAreas = append(weakAreas, map[string]interface{}{
			"topic":            topic.String,
			"total_attempts":   totalAttempts,
			"correct_attempts": correctAttempts,
		})
	}
	// rows.Next also returns false when iteration aborts (e.g. connection loss or
	// context cancellation); surface that instead of returning a partial result.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error during rows iteration")
	}

	span.SetAttributes(
		attribute.Int("weak_areas_count", len(weakAreas)),
		attribute.Int("scan_errors", scanErrors),
	)
	return weakAreas, nil
}
// Priority generation methods moved to worker
// GetHighPriorityTopics returns at most five topic categories of the given
// user's questions whose average priority score is at least 150, ordered from
// the highest average downwards. Questions with a NULL or empty topic category
// are ignored.
//
// The result is never nil: an empty slice is returned when no topic qualifies.
//
// Rows that fail to scan are skipped and counted in the "scan_errors" span
// attribute. An error raised while iterating the result set (reported by
// rows.Err) is returned instead of a silently truncated list.
func (s *LearningService) GetHighPriorityTopics(ctx context.Context, userID int) (result0 []string, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_high_priority_topics",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT q.topic_category, AVG(qps.priority_score) as avg_score
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
WHERE uq.user_id = $1
AND q.topic_category IS NOT NULL
AND q.topic_category != ''
GROUP BY q.topic_category
HAVING AVG(qps.priority_score) >= 150.0
ORDER BY avg_score DESC
LIMIT 5
`
	rows, err := s.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get high priority topics")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	// Start with an empty, non-nil slice so callers always receive a slice.
	topics := []string{}
	var scanErrors int
	for rows.Next() {
		var topic string
		// avgScore is only needed for ordering in SQL; it is scanned to match the column list.
		var avgScore float64
		if scanErr := rows.Scan(&topic, &avgScore); scanErr != nil {
			// Best effort: skip the undecodable row but keep the failure visible in the trace.
			scanErrors++
			continue
		}
		topics = append(topics, topic)
	}
	// rows.Next also returns false when iteration aborts (e.g. connection loss or
	// context cancellation); surface that instead of returning a partial result.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error during rows iteration")
	}

	span.SetAttributes(
		attribute.Int("topics_count", len(topics)),
		attribute.Int("scan_errors", scanErrors),
	)
	return topics, nil
}
// GetGapAnalysis identifies topics in which the given user performs poorly.
//
// It selects up to ten performance_metrics rows of the user with at least three
// attempts and an accuracy below 70%, lowest accuracy first. The result maps a
// topic name to a map with the keys "topic", "total_questions" (int) and
// "accuracy_percentage" (float64, 0.0 when NULL). The map is empty, never nil,
// when nothing matched.
//
// NOTE(review): the query groups by topic together with both attempt counters,
// so "total_questions" is the number of rows sharing that exact triple, and a
// later row for the same topic overwrites an earlier one in the result map —
// confirm this is intended.
//
// Rows that fail to scan are skipped and counted in the "scan_errors" span
// attribute. An error raised while iterating the result set (reported by
// rows.Err) is returned instead of a silently truncated result.
func (s *LearningService) GetGapAnalysis(ctx context.Context, userID int) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_gap_analysis",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Query to find areas where user has poor performance (low accuracy)
	query := `
SELECT
pm.topic,
COUNT(*) as total_questions,
ROUND((pm.correct_attempts * 100.0 / pm.total_attempts), 2) as accuracy_percentage
FROM performance_metrics pm
WHERE pm.user_id = $1
AND pm.total_attempts >= 3
AND (pm.correct_attempts * 100.0 / pm.total_attempts) < 70.0
GROUP BY pm.topic, pm.correct_attempts, pm.total_attempts
ORDER BY accuracy_percentage ASC
LIMIT 10
`
	rows, err := s.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get gap analysis")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	gaps := make(map[string]interface{})
	var scanErrors int
	for rows.Next() {
		var topic string
		var totalQuestions int
		var accuracyPercentage sql.NullFloat64
		if scanErr := rows.Scan(&topic, &totalQuestions, &accuracyPercentage); scanErr != nil {
			// Best effort: skip the undecodable row but keep the failure visible in the trace.
			scanErrors++
			continue
		}
		gapInfo := map[string]interface{}{
			"topic":               topic,
			"total_questions":     totalQuestions,
			"accuracy_percentage": 0.0,
		}
		if accuracyPercentage.Valid {
			gapInfo["accuracy_percentage"] = accuracyPercentage.Float64
		}
		gaps[topic] = gapInfo
	}
	// rows.Next also returns false when iteration aborts (e.g. connection loss or
	// context cancellation); surface that instead of returning a partial result.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error during rows iteration")
	}

	span.SetAttributes(
		attribute.Int("gaps_count", len(gaps)),
		attribute.Int("scan_errors", scanErrors),
	)
	return gaps, nil
}
// GetPriorityDistribution returns, for the given user, how many questions with
// a priority score exist per topic category. Questions with a NULL or empty
// topic category are ignored.
//
// The result maps topic category to question count and is empty, never nil,
// when nothing matched.
//
// Rows that fail to scan are skipped and counted in the "scan_errors" span
// attribute. An error raised while iterating the result set (reported by
// rows.Err) is returned instead of a silently truncated result.
func (s *LearningService) GetPriorityDistribution(ctx context.Context, userID int) (result0 map[string]int, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_priority_distribution",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Query to get priority score distribution by topic
	query := `
SELECT q.topic_category, COUNT(*) as question_count
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
WHERE uq.user_id = $1
AND q.topic_category IS NOT NULL
AND q.topic_category != ''
GROUP BY q.topic_category
ORDER BY question_count DESC
`
	rows, err := s.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get priority distribution")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	distribution := make(map[string]int)
	var scanErrors int
	for rows.Next() {
		var topic string
		var count int
		if scanErr := rows.Scan(&topic, &count); scanErr != nil {
			// Best effort: skip the undecodable row but keep the failure visible in the trace.
			scanErrors++
			continue
		}
		distribution[topic] = count
	}
	// rows.Next also returns false when iteration aborts (e.g. connection loss or
	// context cancellation); surface that instead of returning a partial result.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error during rows iteration")
	}

	span.SetAttributes(
		attribute.Int("topics_count", len(distribution)),
		attribute.Int("scan_errors", scanErrors),
	)
	return distribution, nil
}
// GetUserQuestionConfidenceLevel looks up the confidence level a user recorded
// for a question in user_question_metadata.
//
// It returns (nil, nil) when no metadata row exists for the pair or when the
// stored confidence level is NULL; otherwise it returns a pointer to the level.
// Any other database error is returned wrapped.
func (s *LearningService) GetUserQuestionConfidenceLevel(ctx context.Context, userID, questionID int) (result0 *int, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_question_confidence_level",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const confidenceSQL = `
SELECT confidence_level
FROM user_question_metadata
WHERE user_id = $1 AND question_id = $2
`

	var stored sql.NullInt32
	row := s.db.QueryRowContext(ctx, confidenceSQL, userID, questionID)
	switch err = row.Scan(&stored); {
	case err == sql.ErrNoRows:
		// Nothing recorded for this user-question pair; that is not an error.
		return nil, nil
	case err != nil:
		return nil, contextutils.WrapError(err, "failed to get user question confidence level")
	}

	if !stored.Valid {
		return nil, nil
	}
	level := int(stored.Int32)
	return &level, nil
}
package services
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// uuidRegex matches a whole string in the standard UUID layout: five groups of
// 8-4-4-4-12 hexadecimal digits separated by hyphens. Only lowercase hex digits
// are accepted, so callers lowercase the candidate before matching to get a
// case-insensitive check.
var uuidRegex = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)
// Linear API constants
const (
// LinearAPIEndpoint is the URL of Linear's GraphQL API. It is the endpoint
// NewLinearService configures, and the fallback used when a service's apiURL
// is empty.
LinearAPIEndpoint = "https://api.linear.app/graphql"
// LinearHTTPTimeout is the timeout set on the HTTP client that the
// LinearService constructors build for Linear API requests.
LinearHTTPTimeout = 30 * time.Second
)
// LinearService handles Linear API integration. It sends GraphQL requests over
// its own HTTP client and authenticates with the API key from the configuration.
type LinearService struct {
// config is the application configuration; config.Linear.APIKey is sent as
// the Authorization header of each request.
config *config.Config
// httpClient is the client used for all Linear API calls.
httpClient *http.Client
// logger receives warnings such as failures to close a response body.
logger *observability.Logger
apiURL string // Allow overriding API endpoint for testing
}
// LinearIssueResponse represents the response from Linear API for an
// "issueCreate" operation, as a GraphQL envelope with a "data" object and an
// optional "errors" list.
type LinearIssueResponse struct {
// Data carries the result payload of the issueCreate operation.
Data struct {
IssueCreate struct {
// Success is decoded from the "success" field of the issueCreate result.
Success bool `json:"success"`
// Issue holds the id, title and url fields of the returned issue.
Issue struct {
ID string `json:"id"`
Title string `json:"title"`
URL string `json:"url"`
} `json:"issue"`
} `json:"issueCreate"`
} `json:"data"`
// Errors lists the entries of the response's "errors" array, if any; each
// carries a message plus optional extensions and path values.
Errors []struct {
Message string `json:"message"`
Extensions map[string]interface{} `json:"extensions,omitempty"`
Path []interface{} `json:"path,omitempty"`
} `json:"errors,omitempty"`
}
// LinearIssueResult represents the result of creating a Linear issue: the
// issue's identifier, its URL and its title, serialized under the JSON keys
// "issue_id", "issue_url" and "title".
type LinearIssueResult struct {
IssueID string `json:"issue_id"`
IssueURL string `json:"issue_url"`
Title string `json:"title"`
}
// NewLinearService creates a LinearService that talks to the production Linear
// GraphQL endpoint (LinearAPIEndpoint). The HTTP client it carries uses
// LinearHTTPTimeout and an OpenTelemetry-instrumented transport.
func NewLinearService(cfg *config.Config, logger *observability.Logger) *LinearService {
	// Same construction as the URL-aware constructor, pinned to the default endpoint.
	return NewLinearServiceWithURL(cfg, logger, LinearAPIEndpoint)
}
// NewLinearServiceWithURL creates a LinearService that sends its requests to
// apiURL instead of the default endpoint (intended for tests). The HTTP client
// uses LinearHTTPTimeout and wraps http.DefaultTransport in an OpenTelemetry
// transport that marks its spans as client spans.
func NewLinearServiceWithURL(cfg *config.Config, logger *observability.Logger, apiURL string) *LinearService {
	tracedTransport := otelhttp.NewTransport(
		http.DefaultTransport,
		otelhttp.WithSpanOptions(trace.WithSpanKind(trace.SpanKindClient)),
	)
	client := &http.Client{
		Timeout:   LinearHTTPTimeout,
		Transport: tracedTransport,
	}

	svc := new(LinearService)
	svc.config = cfg
	svc.httpClient = client
	svc.logger = logger
	svc.apiURL = apiURL
	return svc
}
// getTeamIDByName resolves a team identifier to a Linear team ID.
//
// A value that already has UUID shape (checked case-insensitively) is returned
// unchanged without contacting Linear. Any other value is treated as a team
// name: the teams are fetched via the GraphQL "teams" query and the first one
// whose name equals teamIdentifier, ignoring case, wins.
//
// Errors are returned when the request cannot be built, sent or read, when
// Linear answers with a non-200 status or a GraphQL error, or when no team
// carries the requested name.
func (s *LinearService) getTeamIDByName(ctx context.Context, teamIdentifier string) (string, error) {
	// Already an ID: nothing to look up.
	if looksLikeUUID := uuidRegex.MatchString(strings.ToLower(teamIdentifier)); looksLikeUUID {
		return teamIdentifier, nil
	}

	const teamsQuery = `
query Teams {
teams {
nodes {
id
name
}
}
}
`
	payload, err := json.Marshal(map[string]interface{}{"query": teamsQuery})
	if err != nil {
		return "", contextutils.WrapError(err, "failed to marshal team lookup request")
	}

	// Fall back to the production endpoint when no override is configured.
	endpoint := s.apiURL
	if endpoint == "" {
		endpoint = LinearAPIEndpoint
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(payload))
	if err != nil {
		return "", contextutils.WrapError(err, "failed to create team lookup request")
	}
	for header, value := range map[string]string{
		"Content-Type":  "application/json",
		"Authorization": s.config.Linear.APIKey,
		"User-Agent":    "quizapp/1.0",
	} {
		request.Header.Set(header, value)
	}

	response, err := s.httpClient.Do(request)
	if err != nil {
		return "", contextutils.WrapErrorf(err, "failed to query Linear teams")
	}
	defer func() {
		if closeErr := response.Body.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// Read the body first so it can be quoted in the error for non-200 answers.
	raw, err := io.ReadAll(response.Body)
	if err != nil {
		return "", contextutils.WrapError(err, "failed to read team lookup response")
	}
	if response.StatusCode != http.StatusOK {
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			fmt.Sprintf("Linear API returned status %d when looking up teams: %s", response.StatusCode, string(raw)),
			"",
		)
	}

	type teamNode struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}
	var decoded struct {
		Data struct {
			Teams struct {
				Nodes []teamNode `json:"nodes"`
			} `json:"teams"`
		} `json:"data"`
		Errors []struct {
			Message string `json:"message"`
		} `json:"errors,omitempty"`
	}
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return "", contextutils.WrapError(err, "failed to unmarshal team lookup response")
	}

	// GraphQL reports failures inside a 200 response; only the first message is surfaced.
	if len(decoded.Errors) > 0 {
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			fmt.Sprintf("Linear API error when looking up teams: %s", decoded.Errors[0].Message),
			"",
		)
	}

	// Team names are compared case-insensitively; the first match is used.
	nodes := decoded.Data.Teams.Nodes
	for i := range nodes {
		if strings.EqualFold(nodes[i].Name, teamIdentifier) {
			return nodes[i].ID, nil
		}
	}

	return "", contextutils.NewAppError(
		contextutils.ErrorCodeInvalidInput,
		contextutils.SeverityError,
		fmt.Sprintf("Team '%s' not found in Linear", teamIdentifier),
		"",
	)
}
// getProjectIDByName resolves projectIdentifier to a Linear project ID.
//
// An identifier that matches uuidRegex (compared in lower case) is handed back
// untouched and Linear is not contacted. Anything else is treated as a project
// name: the projects of team teamID are fetched over GraphQL and the ID of the
// first project whose name equals the identifier, ignoring case, is returned.
//
// Request construction, transport, non-200 status, decoding and GraphQL errors
// are all reported to the caller, as is an invalid-input error when the team
// has no project with that name.
func (s *LinearService) getProjectIDByName(ctx context.Context, projectIdentifier, teamID string) (string, error) {
	if lowered := strings.ToLower(projectIdentifier); uuidRegex.MatchString(lowered) {
		return projectIdentifier, nil
	}

	// The query text is sent verbatim to Linear.
	const gqlQuery = `
query Projects($teamId: String!) {
team(id: $teamId) {
projects {
nodes {
id
name
}
}
}
}
`
	payload := map[string]interface{}{
		"query":     gqlQuery,
		"variables": map[string]interface{}{"teamId": teamID},
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return "", contextutils.WrapError(err, "failed to marshal project lookup request")
	}

	// A service-level URL override wins over the package default endpoint.
	endpoint := LinearAPIEndpoint
	if s.apiURL != "" {
		endpoint = s.apiURL
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return "", contextutils.WrapError(err, "failed to create project lookup request")
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", s.config.Linear.APIKey)
	httpReq.Header.Set("User-Agent", "quizapp/1.0")

	httpResp, err := s.httpClient.Do(httpReq)
	if err != nil {
		return "", contextutils.WrapErrorf(err, "failed to query Linear projects")
	}
	defer func() {
		closeErr := httpResp.Body.Close()
		if closeErr == nil {
			return
		}
		s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": closeErr.Error()})
	}()

	// The body is read before the status check so it can be quoted in the error.
	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return "", contextutils.WrapError(err, "failed to read project lookup response")
	}
	if httpResp.StatusCode != http.StatusOK {
		statusMsg := fmt.Sprintf("Linear API returned status %d when looking up projects: %s", httpResp.StatusCode, string(raw))
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			statusMsg,
			"",
		)
	}

	type namedNode struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}
	var parsed struct {
		Data struct {
			Team struct {
				Projects struct {
					Nodes []namedNode `json:"nodes"`
				} `json:"projects"`
			} `json:"team"`
		} `json:"data"`
		Errors []struct {
			Message string `json:"message"`
		} `json:"errors,omitempty"`
	}
	if decodeErr := json.Unmarshal(raw, &parsed); decodeErr != nil {
		return "", contextutils.WrapError(decodeErr, "failed to unmarshal project lookup response")
	}
	if len(parsed.Errors) != 0 {
		// Only the first GraphQL error is surfaced.
		gqlMsg := fmt.Sprintf("Linear API error when looking up projects: %s", parsed.Errors[0].Message)
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			gqlMsg,
			"",
		)
	}

	// First case-insensitive name match wins.
	candidates := parsed.Data.Team.Projects.Nodes
	for i := range candidates {
		if strings.EqualFold(candidates[i].Name, projectIdentifier) {
			return candidates[i].ID, nil
		}
	}

	return "", contextutils.NewAppError(
		contextutils.ErrorCodeInvalidInput,
		contextutils.SeverityError,
		fmt.Sprintf("Project '%s' not found in team", projectIdentifier),
		"",
	)
}
// getLabelIDByName resolves labelIdentifier to a Linear label ID using the
// organization-level (workspace-wide) labels.
//
// An identifier that matches uuidRegex (compared in lower case, the same test
// getProjectIDByName applies) is returned unchanged without contacting Linear.
// Any other value is treated as a label name and matched, ignoring case,
// against the organization's labels.
//
// Team- and project-scoped labels are NOT searched here because no team or
// project ID is available; callers that need them fall back to
// getTeamLabelIDByName / getProjectLabelIDByName.
//
// Request construction, transport, non-200 status, decoding and GraphQL errors
// are returned to the caller, as is an invalid-input error when no
// organization label has that name.
func (s *LinearService) getLabelIDByName(ctx context.Context, labelIdentifier string) (string, error) {
	// Use the shared UUID pattern rather than a length/hyphen heuristic, so a
	// 36-character label name containing a hyphen is not mistaken for an ID.
	if uuidRegex.MatchString(strings.ToLower(labelIdentifier)) {
		return labelIdentifier, nil
	}

	// The query text is sent verbatim to Linear.
	query := `
query Labels {
organization {
labels {
nodes {
id
name
}
}
}
}
`
	requestBody := map[string]interface{}{
		"query": query,
	}
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		return "", contextutils.WrapError(err, "failed to marshal label lookup request")
	}

	// A service-level URL override wins over the package default endpoint.
	apiURL := s.apiURL
	if apiURL == "" {
		apiURL = LinearAPIEndpoint
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
	if err != nil {
		return "", contextutils.WrapError(err, "failed to create label lookup request")
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", s.config.Linear.APIKey)
	req.Header.Set("User-Agent", "quizapp/1.0")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return "", contextutils.WrapErrorf(err, "failed to query Linear labels")
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// The body is read before the status check so it can be quoted in the error.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", contextutils.WrapError(err, "failed to read label lookup response")
	}
	if resp.StatusCode != http.StatusOK {
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			fmt.Sprintf("Linear API returned status %d when looking up labels: %s", resp.StatusCode, string(body)),
			"",
		)
	}

	var labelResponse struct {
		Data struct {
			Organization struct {
				Labels struct {
					Nodes []struct {
						ID   string `json:"id"`
						Name string `json:"name"`
					} `json:"nodes"`
				} `json:"labels"`
			} `json:"organization"`
		} `json:"data"`
		Errors []struct {
			Message string `json:"message"`
		} `json:"errors,omitempty"`
	}
	if err := json.Unmarshal(body, &labelResponse); err != nil {
		return "", contextutils.WrapError(err, "failed to unmarshal label lookup response")
	}
	if len(labelResponse.Errors) > 0 {
		// Only the first GraphQL error is surfaced.
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			fmt.Sprintf("Linear API error when looking up labels: %s", labelResponse.Errors[0].Message),
			"",
		)
	}

	// First case-insensitive name match among organization labels wins.
	for _, label := range labelResponse.Data.Organization.Labels.Nodes {
		if strings.EqualFold(label.Name, labelIdentifier) {
			return label.ID, nil
		}
	}

	return "", contextutils.NewAppError(
		contextutils.ErrorCodeInvalidInput,
		contextutils.SeverityError,
		fmt.Sprintf("Label '%s' not found in Linear workspace. Make sure the label exists at the workspace level (Settings > Workspace > Labels)", labelIdentifier),
		"",
	)
}
// getTeamLabelIDByName resolves a label name to its ID among the labels that
// belong to team teamID.
//
// The team's labels are fetched over GraphQL and the ID of the first label
// whose name equals labelIdentifier, ignoring case, is returned. Unlike the
// organization-level lookup there is no UUID short-circuit: the identifier is
// always treated as a name.
//
// Request construction, transport, non-200 status, decoding and GraphQL errors
// are returned to the caller, as is an invalid-input error when the team has no
// label with that name.
func (s *LinearService) getTeamLabelIDByName(ctx context.Context, teamID, labelIdentifier string) (string, error) {
	// The query text is sent verbatim to Linear.
	const gqlQuery = `
query TeamLabels($teamId: String!) {
team(id: $teamId) {
labels {
nodes {
id
name
}
}
}
}
`
	gqlVars := map[string]interface{}{"teamId": teamID}
	encoded, err := json.Marshal(map[string]interface{}{
		"query":     gqlQuery,
		"variables": gqlVars,
	})
	if err != nil {
		return "", contextutils.WrapError(err, "failed to marshal team label lookup request")
	}

	// A service-level URL override wins over the package default endpoint.
	endpoint := LinearAPIEndpoint
	if s.apiURL != "" {
		endpoint = s.apiURL
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return "", contextutils.WrapError(err, "failed to create team label lookup request")
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", s.config.Linear.APIKey)
	httpReq.Header.Set("User-Agent", "quizapp/1.0")

	httpResp, err := s.httpClient.Do(httpReq)
	if err != nil {
		return "", contextutils.WrapErrorf(err, "failed to query Linear team labels")
	}
	defer func() {
		closeErr := httpResp.Body.Close()
		if closeErr == nil {
			return
		}
		s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": closeErr.Error()})
	}()

	// The body is read before the status check so it can be quoted in the error.
	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return "", contextutils.WrapError(err, "failed to read team label lookup response")
	}
	if httpResp.StatusCode != http.StatusOK {
		statusMsg := fmt.Sprintf("Linear API returned status %d when looking up team labels: %s", httpResp.StatusCode, string(raw))
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			statusMsg,
			"",
		)
	}

	type namedNode struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}
	var parsed struct {
		Data struct {
			Team struct {
				Labels struct {
					Nodes []namedNode `json:"nodes"`
				} `json:"labels"`
			} `json:"team"`
		} `json:"data"`
		Errors []struct {
			Message string `json:"message"`
		} `json:"errors,omitempty"`
	}
	if decodeErr := json.Unmarshal(raw, &parsed); decodeErr != nil {
		return "", contextutils.WrapError(decodeErr, "failed to unmarshal team label lookup response")
	}
	if len(parsed.Errors) != 0 {
		// Only the first GraphQL error is surfaced.
		gqlMsg := fmt.Sprintf("Linear API error when looking up team labels: %s", parsed.Errors[0].Message)
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			gqlMsg,
			"",
		)
	}

	// First case-insensitive name match wins.
	candidates := parsed.Data.Team.Labels.Nodes
	for i := range candidates {
		if strings.EqualFold(candidates[i].Name, labelIdentifier) {
			return candidates[i].ID, nil
		}
	}

	return "", contextutils.NewAppError(
		contextutils.ErrorCodeInvalidInput,
		contextutils.SeverityError,
		fmt.Sprintf("Label '%s' not found in Linear team", labelIdentifier),
		"",
	)
}
// getProjectLabelIDByName resolves a label name to its ID among the labels
// attached to project projectID.
//
// The project's labels are fetched over GraphQL and the ID of the first label
// whose name equals labelIdentifier, ignoring case, is returned. There is no
// UUID short-circuit: the identifier is always treated as a name.
//
// Request construction, transport, non-200 status, decoding and GraphQL errors
// are returned to the caller, as is an invalid-input error when the project has
// no label with that name.
func (s *LinearService) getProjectLabelIDByName(ctx context.Context, projectID, labelIdentifier string) (string, error) {
	// The query text is sent verbatim to Linear.
	const gqlQuery = `
query ProjectLabels($projectId: String!) {
project(id: $projectId) {
labels {
nodes {
id
name
}
}
}
}
`
	gqlVars := map[string]interface{}{"projectId": projectID}
	encoded, err := json.Marshal(map[string]interface{}{
		"query":     gqlQuery,
		"variables": gqlVars,
	})
	if err != nil {
		return "", contextutils.WrapError(err, "failed to marshal project label lookup request")
	}

	// A service-level URL override wins over the package default endpoint.
	endpoint := LinearAPIEndpoint
	if s.apiURL != "" {
		endpoint = s.apiURL
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return "", contextutils.WrapError(err, "failed to create project label lookup request")
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", s.config.Linear.APIKey)
	httpReq.Header.Set("User-Agent", "quizapp/1.0")

	httpResp, err := s.httpClient.Do(httpReq)
	if err != nil {
		return "", contextutils.WrapErrorf(err, "failed to query Linear project labels")
	}
	defer func() {
		closeErr := httpResp.Body.Close()
		if closeErr == nil {
			return
		}
		s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": closeErr.Error()})
	}()

	// The body is read before the status check so it can be quoted in the error.
	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return "", contextutils.WrapError(err, "failed to read project label lookup response")
	}
	if httpResp.StatusCode != http.StatusOK {
		statusMsg := fmt.Sprintf("Linear API returned status %d when looking up project labels: %s", httpResp.StatusCode, string(raw))
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			statusMsg,
			"",
		)
	}

	type namedNode struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}
	var parsed struct {
		Data struct {
			Project struct {
				Labels struct {
					Nodes []namedNode `json:"nodes"`
				} `json:"labels"`
			} `json:"project"`
		} `json:"data"`
		Errors []struct {
			Message string `json:"message"`
		} `json:"errors,omitempty"`
	}
	if decodeErr := json.Unmarshal(raw, &parsed); decodeErr != nil {
		return "", contextutils.WrapError(decodeErr, "failed to unmarshal project label lookup response")
	}
	if len(parsed.Errors) != 0 {
		// Only the first GraphQL error is surfaced.
		gqlMsg := fmt.Sprintf("Linear API error when looking up project labels: %s", parsed.Errors[0].Message)
		return "", contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			gqlMsg,
			"",
		)
	}

	// First case-insensitive name match wins.
	candidates := parsed.Data.Project.Labels.Nodes
	for i := range candidates {
		if strings.EqualFold(candidates[i].Name, labelIdentifier) {
			return candidates[i].ID, nil
		}
	}

	return "", contextutils.NewAppError(
		contextutils.ErrorCodeInvalidInput,
		contextutils.SeverityError,
		fmt.Sprintf("Label '%s' not found in Linear project", labelIdentifier),
		"",
	)
}
// resolveLabelIDs turns label names into Linear label IDs on a best-effort basis.
//
// Each name is looked up among organization labels first, then the labels of
// teamID, then (only when projectID is non-empty) the labels of projectID. A
// name that cannot be resolved anywhere is logged as a warning and skipped, so
// the result may be shorter than names and is nil when nothing resolved.
// labelKind is the noun used in the warning text ("Linear label" or
// "default Linear label").
func (s *LinearService) resolveLabelIDs(ctx context.Context, names []string, teamID, projectID, labelKind string) []string {
	var ids []string
	for _, name := range names {
		id, err := s.getLabelIDByName(ctx, name)
		if err != nil {
			// Not an organization label: try the team's own labels.
			id, err = s.getTeamLabelIDByName(ctx, teamID, name)
		}
		if err != nil && projectID != "" {
			// Last resort: labels scoped to the project.
			id, err = s.getProjectLabelIDByName(ctx, projectID, name)
		}
		if err != nil {
			// err is the error of the last scope that was tried.
			fields := map[string]interface{}{
				"label_name": name,
				"team_id":    teamID,
				"error":      err.Error(),
			}
			scopes := "organization and team labels"
			if projectID != "" {
				fields["project_id"] = projectID
				scopes = "organization, team, and project labels"
			}
			s.logger.Warn(ctx, "Failed to look up "+labelKind+" (tried "+scopes+"), continuing without it", fields)
			continue
		}
		ids = append(ids, id)
	}
	return ids
}

// CreateIssue creates a new issue in Linear and returns its ID, URL and title.
//
// teamID and projectID may each be a UUID or a name; names are resolved through
// the Linear API. An empty teamID or projectID falls back to the configured
// default. A team is mandatory, whereas a project that cannot be resolved is
// logged and left out of the issue. labels (or, when labels is empty, the
// configured default labels) are resolved to IDs on a best-effort basis: labels
// that cannot be found are logged and skipped. state is accepted but is not yet
// sent to Linear.
//
// An error is returned when the integration is disabled, the API key or team is
// missing, the team cannot be resolved, the HTTP exchange fails, Linear answers
// with a non-200 status or with GraphQL errors, or the mutation reports that it
// was not successful. The named result err is observed by the deferred span
// finisher.
func (s *LinearService) CreateIssue(ctx context.Context, title, description, teamID, projectID string, labels []string, state string) (result *LinearIssueResult, err error) {
	ctx, span := observability.TraceFunction(ctx, "linear", "create_issue",
		attribute.String("linear.title", title),
		attribute.String("linear.team_id", teamID),
		attribute.String("linear.project_id", projectID),
	)
	defer observability.FinishSpan(span, &err)

	if !s.config.Linear.Enabled {
		return nil, contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			"Linear integration is disabled",
			"",
		)
	}
	if s.config.Linear.APIKey == "" {
		return nil, contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			"Linear API key is not configured",
			"",
		)
	}

	// Fall back to the configured team; a team is mandatory.
	if teamID == "" {
		teamID = s.config.Linear.TeamID
	}
	if teamID == "" {
		return nil, contextutils.NewAppError(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityError,
			"Linear team ID or name is required",
			"",
		)
	}

	// Look up team ID by name if it's not a UUID.
	teamID, err = s.getTeamIDByName(ctx, teamID)
	if err != nil {
		return nil, err
	}

	// Fall back to the configured project, then resolve a name to an ID. The
	// resolved ID is also needed for the project-scoped label lookup below.
	if projectID == "" {
		projectID = s.config.Linear.ProjectID
	}
	if projectID != "" {
		resolvedProjectID, lookupErr := s.getProjectIDByName(ctx, projectID, teamID)
		if lookupErr != nil {
			// A project is optional: warn and create the issue without one.
			s.logger.Warn(ctx, "Failed to look up Linear project, continuing without project", map[string]interface{}{
				"project_identifier": projectID,
				"error":              lookupErr.Error(),
			})
			resolvedProjectID = ""
		}
		projectID = resolvedProjectID
	}

	// Explicit labels take precedence; configured defaults apply only when the
	// caller supplied none.
	var labelIDs []string
	switch {
	case len(labels) > 0:
		labelIDs = s.resolveLabelIDs(ctx, labels, teamID, projectID, "Linear label")
	case len(s.config.Linear.DefaultLabels) > 0:
		labelIDs = s.resolveLabelIDs(ctx, s.config.Linear.DefaultLabels, teamID, projectID, "default Linear label")
	}

	// Use default state if none provided.
	// Note: State is not yet implemented (requires fetching state ID from Linear).
	if state == "" {
		_ = s.config.Linear.DefaultState // Will be used when state ID lookup is implemented
	}

	// Build GraphQL mutation.
	// Required fields: teamId, title
	// Optional fields: description, projectId, assigneeId, labelIds (array of IDs), stateId (ID, not name)
	// The mutation text is sent verbatim to Linear.
	mutation := `
mutation IssueCreate($input: IssueCreateInput!) {
issueCreate(input: $input) {
success
issue {
id
title
url
}
}
}
`
	input := map[string]interface{}{
		"title":  title,
		"teamId": teamID,
	}
	// Only add description if it's not empty (Linear may reject empty strings).
	if description != "" {
		input["description"] = description
	}
	// projectID is either empty or the value returned by getProjectIDByName.
	if projectID != "" {
		input["projectId"] = projectID
	}
	// Add label IDs if any were resolved.
	if len(labelIDs) > 0 {
		input["labelIds"] = labelIDs
	}

	requestBody := map[string]interface{}{
		"query":     mutation,
		"variables": map[string]interface{}{"input": input},
	}
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to marshal GraphQL request")
	}

	// A service-level URL override wins over the package default endpoint.
	apiURL := s.apiURL
	if apiURL == "" {
		apiURL = LinearAPIEndpoint
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to create HTTP request")
	}
	req.Header.Set("Content-Type", "application/json")
	// Personal API keys should NOT use "Bearer" prefix per Linear docs.
	// OAuth2 tokens use "Bearer" prefix, but personal API keys use the key directly.
	req.Header.Set("Authorization", s.config.Linear.APIKey)
	req.Header.Set("User-Agent", "quizapp/1.0")

	startTime := time.Now()
	resp, err := s.httpClient.Do(req)
	duration := time.Since(startTime)
	if err != nil {
		s.logger.Error(ctx, "Linear HTTP request failed", err, map[string]interface{}{
			"duration": duration.String(),
		})
		span.SetAttributes(
			attribute.String("error", err.Error()),
			attribute.String("duration", duration.String()),
		)
		return nil, contextutils.WrapErrorf(err, "Linear HTTP request failed after %v", duration)
	}
	defer func() {
		if cerr := resp.Body.Close(); cerr != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{
				"error": cerr.Error(),
			})
		}
	}()
	span.SetAttributes(
		attribute.Int("http.status_code", resp.StatusCode),
		attribute.String("duration", duration.String()),
	)

	// The body is read before the status check so it can be logged and quoted.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to read response body")
	}
	if resp.StatusCode != http.StatusOK {
		s.logger.Error(ctx, "Linear API returned non-200 status", nil, map[string]interface{}{
			"status_code": resp.StatusCode,
			"body":        string(body),
		})
		span.SetAttributes(
			attribute.String("error", fmt.Sprintf("Linear API returned status %d", resp.StatusCode)),
			attribute.String("response_body", string(body)),
		)
		return nil, contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			fmt.Sprintf("Linear API returned status %d: %s", resp.StatusCode, string(body)),
			"",
		)
	}

	var linearResp LinearIssueResponse
	if err := json.Unmarshal(body, &linearResp); err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to unmarshal Linear response")
	}

	// GraphQL-level failures arrive in the response body alongside HTTP 200.
	if len(linearResp.Errors) > 0 {
		// Log every error, including extensions which may contain validation details.
		errorDetails := make([]map[string]interface{}, len(linearResp.Errors))
		for i, gqlErr := range linearResp.Errors {
			detail := map[string]interface{}{
				"message": gqlErr.Message,
			}
			if len(gqlErr.Extensions) > 0 {
				detail["extensions"] = gqlErr.Extensions
			}
			if len(gqlErr.Path) > 0 {
				detail["path"] = gqlErr.Path
			}
			errorDetails[i] = detail
		}

		// The returned error describes the first GraphQL error only.
		var detailedErrorMsg strings.Builder
		detailedErrorMsg.WriteString(linearResp.Errors[0].Message)
		if len(linearResp.Errors[0].Extensions) > 0 {
			detailedErrorMsg.WriteString("\nExtensions: ")
			// Marshal failure is ignored: the extra detail is best-effort.
			extJSON, _ := json.Marshal(linearResp.Errors[0].Extensions)
			detailedErrorMsg.WriteString(string(extJSON))
		}
		if len(linearResp.Errors[0].Path) > 0 {
			detailedErrorMsg.WriteString("\nPath: ")
			// Marshal failure is ignored: the extra detail is best-effort.
			pathJSON, _ := json.Marshal(linearResp.Errors[0].Path)
			detailedErrorMsg.WriteString(string(pathJSON))
		}

		s.logger.Error(ctx, "Linear GraphQL error", nil, map[string]interface{}{
			"errors":        errorDetails,
			"request_body":  string(jsonData), // Log the request for debugging
			"full_response": string(body),     // Log full response for debugging
		})
		span.SetAttributes(attribute.String("error", detailedErrorMsg.String()))
		return nil, contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			detailedErrorMsg.String(),
			"",
		)
	}

	if !linearResp.Data.IssueCreate.Success {
		s.logger.Error(ctx, "Linear issue creation failed", nil, map[string]interface{}{})
		span.SetAttributes(attribute.String("error", "Linear issue creation was not successful"))
		return nil, contextutils.NewAppError(
			contextutils.ErrorCodeServiceUnavailable,
			contextutils.SeverityError,
			"Linear issue creation was not successful",
			"",
		)
	}

	issue := linearResp.Data.IssueCreate.Issue
	// Construct the URL if not provided (Linear sometimes doesn't return it).
	issueURL := issue.URL
	if issueURL == "" {
		issueURL = fmt.Sprintf("https://linear.app/issue/%s", issue.ID)
	}

	result = &LinearIssueResult{
		IssueID:  issue.ID,
		IssueURL: issueURL,
		Title:    issue.Title,
	}
	s.logger.Info(ctx, "Linear issue created successfully", map[string]interface{}{
		"issue_id":  issue.ID,
		"issue_url": issueURL,
		"duration":  duration.String(),
	})
	span.SetAttributes(
		attribute.String("linear.issue_id", issue.ID),
		attribute.String("linear.issue_url", issueURL),
	)
	return result, nil
}
package services
import (
"fmt"
contextutils "quizapp/internal/utils"
)
// NoQuestionsAvailableError is returned when no suitable questions can be found for assignment.
// It carries the details of the failed search so callers can log or inspect them.
type NoQuestionsAvailableError struct {
	// Language and Level are reported in the error message.
	Language string
	Level    string
	// CandidateIDs is carried for callers; it is not part of the error message.
	CandidateIDs []int
	// CandidateCount and TotalMatching are reported in the error message.
	CandidateCount int
	TotalMatching  int
}

// Error implements the error interface. The message lists the language, level,
// candidate count and total-matching count stored on the error.
func (e *NoQuestionsAvailableError) Error() string {
	const format = "no questions available for assignment (language=%s level=%s candidate_count=%d total_matching=%d)"
	return fmt.Sprintf(format, e.Language, e.Level, e.CandidateCount, e.TotalMatching)
}
// Unwrap exposes contextutils.ErrNoQuestionsAvailable as the wrapped error, so
// errors.Is(err, contextutils.ErrNoQuestionsAvailable) reports true for any
// *NoQuestionsAvailableError regardless of its field values.
func (e *NoQuestionsAvailableError) Unwrap() error {
	sentinel := contextutils.ErrNoQuestionsAvailable
	return sentinel
}
package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
)
// ErrSignupsDisabled is returned when user registration is disabled by config.
// It is a sentinel value: compare with errors.Is rather than by message text.
var ErrSignupsDisabled = errors.New("user registration is currently disabled")
// OAuth sentinel errors.
//
// ExchangeCodeForToken wraps one of these according to the "error" field of a
// failed token response, so callers can branch with errors.Is.
var (
// ErrOAuthCodeAlreadyUsed is used for the provider error "invalid_grant".
ErrOAuthCodeAlreadyUsed = errors.New("authorization code has already been used")
// ErrOAuthClientConfig is used for the provider error "invalid_client".
ErrOAuthClientConfig = errors.New("OAuth client configuration error")
// ErrOAuthInvalidRequest is used for the provider error "invalid_request".
ErrOAuthInvalidRequest = errors.New("invalid OAuth request")
// ErrOAuthUnauthorized is used for the provider error "unauthorized_client".
ErrOAuthUnauthorized = errors.New("OAuth client is not authorized")
// ErrOAuthUnsupportedGrant is used for the provider error "unsupported_grant_type".
ErrOAuthUnsupportedGrant = errors.New("unsupported OAuth grant type")
)
// OAuthService handles OAuth authentication flows against Google.
type OAuthService struct {
// config supplies the Google OAuth client ID, client secret and redirect URL.
config *config.Config
// TokenEndpoint is the URL used to exchange an authorization code for a
// token; exported so tests can point it at a mock. An empty value makes
// ExchangeCodeForToken fall back to Google's token URL.
TokenEndpoint string // for testing/mocking
// UserInfoEndpoint is the URL queried for the user's profile; exported so
// tests can point it at a mock. An empty value makes GetGoogleUserInfo fall
// back to Google's userinfo URL.
UserInfoEndpoint string // for testing/mocking
// logger receives warnings; GetGoogleAuthURL treats a nil logger as "no logging".
logger *observability.Logger
}
// NewOAuthServiceWithLogger builds an OAuthService around cfg and logger, with
// the token and userinfo endpoints preset to Google's production URLs. Both
// endpoint fields are exported and may be overwritten afterwards.
func NewOAuthServiceWithLogger(cfg *config.Config, logger *observability.Logger) *OAuthService {
	svc := new(OAuthService)
	svc.config = cfg
	svc.logger = logger
	svc.TokenEndpoint = "https://oauth2.googleapis.com/token"
	svc.UserInfoEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo"
	return svc
}
// GoogleUserInfo represents the user information returned by Google OAuth.
// GetGoogleUserInfo decodes the JSON body of the userinfo endpoint into it;
// each field is filled from the JSON key named in its tag.
type GoogleUserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Picture string `json:"picture"`
VerifiedEmail bool `json:"verified_email"`
}
// GoogleTokenResponse represents the token response from Google OAuth.
// ExchangeCodeForToken decodes the JSON body of a successful token request
// into it; each field is filled from the JSON key named in its tag.
type GoogleTokenResponse struct {
// AccessToken is the value GetGoogleUserInfo sends as the Bearer credential.
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
// ExpiresIn is recorded on the trace span as "oauth.expires_in".
ExpiresIn int `json:"expires_in"`
// RefreshToken and IDToken are optional in the JSON (omitempty).
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token,omitempty"`
}
// GetGoogleAuthURL returns the Google OAuth 2.0 authorization URL the user
// should be redirected to. The URL carries the configured client ID and
// redirect URL, the given state, the scope "openid email profile", and requests
// offline access with a forced consent prompt.
//
// A missing client ID or redirect URL does not stop URL generation; each is
// reported as a warning when a logger is configured.
func (s *OAuthService) GetGoogleAuthURL(ctx context.Context, state string) string {
	clientID := s.config.GoogleOAuthClientID
	redirectURL := s.config.GoogleOAuthRedirectURL

	// The traced context is discarded; logging below uses the caller's ctx.
	_, span := observability.TraceOAuthFunction(ctx, "get_google_auth_url",
		attribute.String("oauth.state", state),
		attribute.String("oauth.client_id", clientID),
		attribute.String("oauth.redirect_url", redirectURL),
	)
	defer span.End()

	// Surface misconfiguration, but only when there is somewhere to log it.
	if s.logger != nil {
		if clientID == "" {
			s.logger.Warn(ctx, "Google OAuth client ID is not set", map[string]interface{}{"env_var": "GOOGLE_OAUTH_CLIENT_ID"})
		}
		if redirectURL == "" {
			s.logger.Warn(ctx, "Google OAuth redirect URL is not set", map[string]interface{}{"env_var": "GOOGLE_OAUTH_REDIRECT_URL"})
		}
	}

	authParams := url.Values{
		"client_id":     {clientID},
		"redirect_uri":  {redirectURL},
		"response_type": {"code"},
		"scope":         {"openid email profile"},
		"state":         {state},
		"access_type":   {"offline"},
		"prompt":        {"consent"},
	}
	return "https://accounts.google.com/o/oauth2/v2/auth?" + authParams.Encode()
}
// ExchangeCodeForToken exchanges the authorization code for an access token.
//
// It posts the code together with the configured client ID, client secret and
// redirect URL to s.TokenEndpoint (Google's token URL when the field is empty)
// and decodes a 200 response into a GoogleTokenResponse.
//
// On a non-200 response the body is parsed as an OAuth error object and its
// "error" field is mapped to a sentinel:
//
//	invalid_grant          -> ErrOAuthCodeAlreadyUsed
//	invalid_client         -> ErrOAuthClientConfig
//	invalid_request        -> ErrOAuthInvalidRequest
//	unauthorized_client    -> ErrOAuthUnauthorized
//	unsupported_grant_type -> ErrOAuthUnsupportedGrant
//
// Any other error code, and any body without a usable "error" field, is
// wrapped around contextutils.ErrOAuthProviderError. Request construction,
// transport and decode failures are wrapped and returned as well. The named
// result err is observed by the deferred span finisher.
func (s *OAuthService) ExchangeCodeForToken(ctx context.Context, code string) (result0 *GoogleTokenResponse, err error) {
	// NOTE(review): the authorization code is recorded as a span attribute;
	// confirm that exposing it to the tracing backend is acceptable.
	ctx, span := observability.TraceOAuthFunction(ctx, "exchange_code_for_token",
		attribute.String("oauth.code", code),
		attribute.String("oauth.token_endpoint", s.TokenEndpoint),
	)
	defer observability.FinishSpan(span, &err)

	data := url.Values{}
	data.Set("client_id", s.config.GoogleOAuthClientID)
	data.Set("client_secret", s.config.GoogleOAuthClientSecret)
	data.Set("code", code)
	data.Set("grant_type", "authorization_code")
	data.Set("redirect_uri", s.config.GoogleOAuthRedirectURL)

	tokenURL := s.TokenEndpoint
	if tokenURL == "" {
		tokenURL = "https://oauth2.googleapis.com/token"
	}

	// Bind the traced context at construction time instead of cloning the
	// request afterwards with WithContext.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to create token request")
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// Use instrumented HTTP client for automatic tracing with explicit span options.
	client := &http.Client{
		Timeout: config.OAuthHTTPTimeout,
		Transport: otelhttp.NewTransport(http.DefaultTransport,
			otelhttp.WithSpanOptions(trace.WithSpanKind(trace.SpanKindClient)),
		),
	}
	resp, err := client.Do(req)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to exchange code for token")
	}
	defer func() {
		// The logger is optional (GetGoogleAuthURL guards it the same way), so
		// a close failure must not turn into a nil-pointer panic.
		if cerr := resp.Body.Close(); cerr != nil && s.logger != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": cerr.Error()})
		}
	}()
	span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode))

	if resp.StatusCode != http.StatusOK {
		// Best effort: a read failure just leaves less detail for the message.
		body, _ := io.ReadAll(resp.Body)

		// Try to parse the error response for better error messages.
		var errorResp struct {
			Error            string `json:"error"`
			ErrorDescription string `json:"error_description"`
		}
		// Require a non-empty error code: a JSON body without one would
		// otherwise yield the uninformative message "OAuth error:  - ".
		if json.Unmarshal(body, &errorResp) == nil && errorResp.Error != "" {
			span.SetAttributes(
				attribute.String("oauth.error", errorResp.Error),
				attribute.String("oauth.error_description", errorResp.ErrorDescription),
			)
			switch errorResp.Error {
			case "invalid_grant":
				return nil, contextutils.WrapErrorf(ErrOAuthCodeAlreadyUsed, "please try signing in again")
			case "invalid_client":
				return nil, contextutils.WrapError(ErrOAuthClientConfig, "")
			case "invalid_request":
				return nil, contextutils.WrapError(ErrOAuthInvalidRequest, "")
			case "unauthorized_client":
				return nil, contextutils.WrapError(ErrOAuthUnauthorized, "")
			case "unsupported_grant_type":
				return nil, contextutils.WrapError(ErrOAuthUnsupportedGrant, "")
			default:
				return nil, contextutils.WrapErrorf(contextutils.ErrOAuthProviderError, "OAuth error: %s - %s", errorResp.Error, errorResp.ErrorDescription)
			}
		}
		return nil, contextutils.WrapErrorf(contextutils.ErrOAuthProviderError, "token exchange failed with status %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp GoogleTokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to decode token response")
	}
	span.SetAttributes(
		attribute.String("oauth.token_type", tokenResp.TokenType),
		attribute.Int("oauth.expires_in", tokenResp.ExpiresIn),
	)
	return &tokenResp, nil
}
// GetGoogleUserInfo retrieves user information from Google using the access token.
//
// It issues a GET to s.UserInfoEndpoint (Google's userinfo URL when the field
// is empty) with accessToken as a Bearer credential and decodes a 200 response
// into a GoogleUserInfo.
//
// A non-200 response is wrapped around contextutils.ErrOAuthProviderError with
// the status code and body in the message. Request construction, transport and
// decode failures are wrapped and returned as well. The named result err is
// observed by the deferred span finisher.
func (s *OAuthService) GetGoogleUserInfo(ctx context.Context, accessToken string) (result0 *GoogleUserInfo, err error) {
	ctx, span := observability.TraceOAuthFunction(ctx, "get_google_user_info",
		attribute.String("oauth.userinfo_endpoint", s.UserInfoEndpoint),
	)
	defer observability.FinishSpan(span, &err)

	userinfoURL := s.UserInfoEndpoint
	if userinfoURL == "" {
		userinfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
	}

	// Bind the traced context at construction time instead of cloning the
	// request afterwards with WithContext.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, userinfoURL, nil)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to create userinfo request")
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	// Use instrumented HTTP client for automatic tracing with explicit span options.
	client := &http.Client{
		Timeout: config.OAuthHTTPTimeout,
		Transport: otelhttp.NewTransport(http.DefaultTransport,
			otelhttp.WithSpanOptions(trace.WithSpanKind(trace.SpanKindClient)),
		),
	}
	resp, err := client.Do(req)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to get user info")
	}
	defer func() {
		// The logger is optional (GetGoogleAuthURL guards it the same way), so
		// a close failure must not turn into a nil-pointer panic.
		if cerr := resp.Body.Close(); cerr != nil && s.logger != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": cerr.Error()})
		}
	}()
	span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode))

	if resp.StatusCode != http.StatusOK {
		// Best effort: a read failure just leaves less detail for the message.
		body, _ := io.ReadAll(resp.Body)
		span.SetAttributes(attribute.String("error", fmt.Sprintf("userinfo request failed with status %d: %s", resp.StatusCode, string(body))))
		return nil, contextutils.WrapErrorf(contextutils.ErrOAuthProviderError, "userinfo request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var userInfo GoogleUserInfo
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to decode user info")
	}
	span.SetAttributes(
		attribute.String("user.email", userInfo.Email),
		attribute.String("user.id", userInfo.ID),
		attribute.Bool("user.verified_email", userInfo.VerifiedEmail),
	)
	return &userInfo, nil
}
// AuthenticateGoogleUser handles the complete Google OAuth flow.
//
// It performs, in order:
//  1. exchange of the authorization code for a token (ExchangeCodeForToken);
//  2. retrieval of the Google profile for that token (GetGoogleUserInfo);
//  3. lookup of an existing account by the profile's email, which is returned
//     as-is when found;
//  4. when s.config reports that signups are disabled, a whitelist check via
//     IsOAuthSignupAllowed; a rejected email yields ErrSignupsDisabled;
//  5. creation of a new user whose username is the email, or the email with
//     a "_<n>" suffix when that username is already taken.
//
// New users are created with timezone "UTC", language "italian" and level
// "beginner". Errors from the individual steps are returned wrapped with
// context.
func (s *OAuthService) AuthenticateGoogleUser(ctx context.Context, code string, userService UserServiceInterface) (result0 *models.User, err error) {
	ctx, span := observability.TraceOAuthFunction(ctx, "authenticate_google_user",
		// The authorization code is a credential: record only whether one was
		// supplied instead of exporting its value to the tracing backend.
		attribute.Bool("oauth.code_present", code != ""),
	)
	defer observability.FinishSpan(span, &err)

	// Exchange code for token
	tokenResp, err := s.ExchangeCodeForToken(ctx, code)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to exchange code for token")
	}

	// Get user info from Google
	userInfo, err := s.GetGoogleUserInfo(ctx, tokenResp.AccessToken)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to get user info")
	}
	span.SetAttributes(
		attribute.String("user.email", userInfo.Email),
		attribute.String("user.id", userInfo.ID),
	)

	// Check if user exists by email
	existingUser, err := userService.GetUserByEmail(ctx, userInfo.Email)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to check existing user")
	}
	if existingUser != nil {
		// User exists, return the user
		span.SetAttributes(
			attribute.Int("user.id", existingUser.ID),
			attribute.String("auth.result", "existing_user"),
		)
		return existingUser, nil
	}

	// Check if signups are disabled before creating new user
	if s.config != nil && s.config.IsSignupDisabled() {
		// Check if OAuth signup is allowed via whitelist
		if !s.config.IsOAuthSignupAllowed(userInfo.Email) {
			span.SetAttributes(
				attribute.String("auth.result", "oauth_signup_blocked"),
				attribute.String("user.email", userInfo.Email),
			)
			return nil, ErrSignupsDisabled
		}
		// Allow OAuth signup for whitelisted email/domain
		span.SetAttributes(
			attribute.String("auth.result", "oauth_signup_allowed"),
			attribute.String("user.email", userInfo.Email),
		)
	}

	// User doesn't exist, create new user.
	// Use email as username; on conflict append "_1", "_2", ... until free.
	username := userInfo.Email
	email := userInfo.Email
	for counter := 1; ; counter++ {
		// Distinct names avoid shadowing the outer existingUser/err.
		taken, lookupErr := userService.GetUserByUsername(ctx, username)
		if lookupErr != nil {
			span.SetAttributes(attribute.String("error", lookupErr.Error()))
			return nil, contextutils.WrapError(lookupErr, "failed to check username availability")
		}
		if taken == nil {
			break
		}
		username = fmt.Sprintf("%s_%d", userInfo.Email, counter)
	}

	span.SetAttributes(
		attribute.String("user.username", username),
		attribute.String("user.email", email),
		attribute.String("auth.result", "new_user"),
	)

	// Create user with default settings
	user, err := userService.CreateUserWithEmailAndTimezone(ctx, username, email, "UTC", "italian", "beginner")
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to create user")
	}
	span.SetAttributes(attribute.Int("user.id", user.ID))
	return user, nil
}
package services
import (
"context"
"database/sql"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// QuestionServiceInterface defines the interface for question-related operations.
// This allows for easier mocking in tests.
//
// The per-method notes below describe what the QuestionService methods of the
// same name in this file do; methods without a note are not described here.
type QuestionServiceInterface interface {
// SaveQuestion inserts the question and, on success, stores the generated ID on it.
SaveQuestion(ctx context.Context, question *models.Question) error
// AssignQuestionToUser links a question to a user; an already existing link is left untouched.
AssignQuestionToUser(ctx context.Context, questionID, userID int) error
// GetQuestionByID returns the question with the given ID.
GetQuestionByID(ctx context.Context, id int) (*models.Question, error)
// GetQuestionWithStats returns the question together with its correct, incorrect and total response counts.
GetQuestionWithStats(ctx context.Context, id int) (*QuestionWithStats, error)
// GetQuestionsByFilter returns up to limit active questions assigned to the user for the
// language and level, in random order. An empty questionType disables type filtering.
GetQuestionsByFilter(ctx context.Context, userID int, language, level string, questionType models.QuestionType, limit int) ([]models.Question, error)
GetNextQuestion(ctx context.Context, userID int, language, level string, qType models.QuestionType) (*QuestionWithStats, error)
GetAdaptiveQuestionsForDaily(ctx context.Context, userID int, language, level string, limit int) ([]*QuestionWithStats, error)
// ReportQuestion marks a question as reported/problematic by a specific user.
ReportQuestion(ctx context.Context, questionID, userID int, reportReason string) error
GetQuestionStats(ctx context.Context) (map[string]interface{}, error)
GetDetailedQuestionStats(ctx context.Context) (map[string]interface{}, error)
GetRecentQuestionContentsForUser(ctx context.Context, userID, limit int) ([]string, error)
// GetReportedQuestions returns the questions whose status is "reported", newest first.
GetReportedQuestions(ctx context.Context) ([]*ReportedQuestionWithUser, error)
// MarkQuestionAsFixed sets the question's status back to active.
MarkQuestionAsFixed(ctx context.Context, questionID int) error
// UpdateQuestion replaces the question's content, correct answer and explanation.
UpdateQuestion(ctx context.Context, questionID int, content map[string]interface{}, correctAnswerIndex int, explanation string) error
// DeleteQuestion removes the question and the user responses that reference it.
DeleteQuestion(ctx context.Context, questionID int) error
GetUserQuestions(ctx context.Context, userID, limit int) ([]*models.Question, error)
GetUserQuestionsWithStats(ctx context.Context, userID, limit int) ([]*QuestionWithStats, error)
GetQuestionsPaginated(ctx context.Context, userID, page, pageSize int, search, typeFilter, statusFilter string) ([]*QuestionWithStats, int, error)
GetAllQuestionsPaginated(ctx context.Context, page, pageSize int, search, typeFilter, statusFilter, languageFilter, levelFilter string, userID *int) ([]*QuestionWithStats, int, error)
GetReportedQuestionsPaginated(ctx context.Context, page, pageSize int, search, typeFilter, languageFilter, levelFilter string) ([]*QuestionWithStats, int, error)
GetReportedQuestionsStats(ctx context.Context) (map[string]interface{}, error)
GetUserQuestionCount(ctx context.Context, userID int) (int, error)
GetUserResponseCount(ctx context.Context, userID int) (int, error)
GetRandomGlobalQuestionForUser(ctx context.Context, userID int, language, level string, qType models.QuestionType) (*QuestionWithStats, error)
GetUsersForQuestion(ctx context.Context, questionID int) ([]*models.User, int, error)
AssignUsersToQuestion(ctx context.Context, questionID int, userIDs []int) error
UnassignUsersFromQuestion(ctx context.Context, questionID int, userIDs []int) error
// DB presumably exposes the underlying database handle — confirm against the implementation.
DB() *sql.DB
}
// QuestionService provides methods for question management.
// Instances are built by NewQuestionServiceWithLogger, which panics on a nil
// db or logger.
type QuestionService struct {
// db is the handle every query and statement in this service runs against.
db *sql.DB
// learningService is stored by the constructor and may be nil.
// NOTE(review): its use is not visible in this part of the file.
learningService *LearningService
// logger receives warnings such as failures to close result rows.
logger *observability.Logger
// cfg is optional; getDailyRepeatAvoidDays reads Server.DailyRepeatAvoidDays from it when non-nil.
cfg *config.Config
}
// Shared query constants to eliminate duplication.

// questionSelectFields is the column list, in scan order, used by SELECT
// queries that load a complete question row: the ten core columns followed by
// the seven optional variety columns.
const questionSelectFields = "id, type, language, level, difficulty_score, " +
	"content, correct_answer, explanation, created_at, status, " +
	"topic_category, grammar_focus, vocabulary_domain, scenario, " +
	"style_modifier, difficulty_modifier, time_context"
// scanQuestionFromRow reads a single-row result into a models.Question.
//
// The row must expose its columns in questionSelectFields order. NULL values
// in the seven optional text columns leave the corresponding field at its
// zero value. The content column is decoded via UnmarshalContentFromJSON.
// Scan and decode errors are returned unchanged, with a nil question.
func (s *QuestionService) scanQuestionFromRow(row *sql.Row) (result0 *models.Question, err error) {
	var (
		q          models.Question
		rawContent string
		// Nullable text columns, declared in column order.
		topic, grammar, vocab, scene, style, difficulty, timeCtx sql.NullString
	)

	err = row.Scan(
		&q.ID, &q.Type, &q.Language, &q.Level, &q.DifficultyScore,
		&rawContent, &q.CorrectAnswer, &q.Explanation, &q.CreatedAt, &q.Status,
		&topic, &grammar, &vocab, &scene, &style, &difficulty, &timeCtx,
	)
	if err != nil {
		return nil, err
	}

	// Copy each optional column only when the database value was not NULL.
	for _, col := range []struct {
		value  sql.NullString
		target *string
	}{
		{topic, &q.TopicCategory},
		{grammar, &q.GrammarFocus},
		{vocab, &q.VocabularyDomain},
		{scene, &q.Scenario},
		{style, &q.StyleModifier},
		{difficulty, &q.DifficultyModifier},
		{timeCtx, &q.TimeContext},
	} {
		if col.value.Valid {
			*col.target = col.value.String
		}
	}

	if decodeErr := q.UnmarshalContentFromJSON(rawContent); decodeErr != nil {
		return nil, decodeErr
	}
	return &q, nil
}
// scanQuestionFromRows reads the current row of rows into a models.Question.
//
// The expected column order is the same as for scanQuestionFromRow: ten core
// columns followed by seven nullable text columns. NULLs leave the matching
// field empty. Scan and content-decoding errors are returned unchanged.
func (s *QuestionService) scanQuestionFromRows(rows *sql.Rows) (result0 *models.Question, err error) {
	var (
		topic, grammar, vocab, scene, style, difficulty, timeCtx sql.NullString
		rawContent                                               string
	)
	q := new(models.Question)

	if err = rows.Scan(
		&q.ID,
		&q.Type,
		&q.Language,
		&q.Level,
		&q.DifficultyScore,
		&rawContent,
		&q.CorrectAnswer,
		&q.Explanation,
		&q.CreatedAt,
		&q.Status,
		&topic,
		&grammar,
		&vocab,
		&scene,
		&style,
		&difficulty,
		&timeCtx,
	); err != nil {
		return nil, err
	}

	// copyIfSet writes src into dst unless the column was NULL.
	copyIfSet := func(dst *string, src sql.NullString) {
		if src.Valid {
			*dst = src.String
		}
	}
	copyIfSet(&q.TopicCategory, topic)
	copyIfSet(&q.GrammarFocus, grammar)
	copyIfSet(&q.VocabularyDomain, vocab)
	copyIfSet(&q.Scenario, scene)
	copyIfSet(&q.StyleModifier, style)
	copyIfSet(&q.DifficultyModifier, difficulty)
	copyIfSet(&q.TimeContext, timeCtx)

	if decodeErr := q.UnmarshalContentFromJSON(rawContent); decodeErr != nil {
		return nil, decodeErr
	}
	return q, nil
}
// scanQuestionBasicFromRows reads the current row of rows into a
// models.Question using only the ten core columns (id through status); the
// optional variety columns are not read. The content column is decoded via
// UnmarshalContentFromJSON. Errors are returned unchanged.
func (s *QuestionService) scanQuestionBasicFromRows(rows *sql.Rows) (result0 *models.Question, err error) {
	var (
		q          models.Question
		rawContent string
	)

	dest := []interface{}{
		&q.ID, &q.Type, &q.Language, &q.Level, &q.DifficultyScore,
		&rawContent, &q.CorrectAnswer, &q.Explanation, &q.CreatedAt, &q.Status,
	}
	if err = rows.Scan(dest...); err != nil {
		return nil, err
	}

	if decodeErr := q.UnmarshalContentFromJSON(rawContent); decodeErr != nil {
		return nil, decodeErr
	}
	return &q, nil
}
// scanQuestionWithStatsFromRows reads the current row of rows into a
// QuestionWithStats: the ten core question columns followed by the correct,
// incorrect, total-response and user counts. Errors are returned unchanged.
func (s *QuestionService) scanQuestionWithStatsFromRows(rows *sql.Rows) (result0 *QuestionWithStats, err error) {
	out := &QuestionWithStats{Question: &models.Question{}}
	var rawContent string

	dest := []interface{}{
		// Core question columns.
		&out.ID, &out.Type, &out.Language, &out.Level, &out.DifficultyScore,
		&rawContent, &out.CorrectAnswer, &out.Explanation, &out.CreatedAt, &out.Status,
		// Aggregated statistics.
		&out.CorrectCount, &out.IncorrectCount, &out.TotalResponses, &out.UserCount,
	}
	if err = rows.Scan(dest...); err != nil {
		return nil, err
	}

	if decodeErr := out.UnmarshalContentFromJSON(rawContent); decodeErr != nil {
		return nil, decodeErr
	}
	return out, nil
}
// scanQuestionWithStatsAndAllFieldsFromRows reads the current row of rows into
// a QuestionWithStats, covering every question column (core plus the seven
// nullable variety columns) followed by the correct, incorrect,
// total-response and user counts. NULL variety columns leave the matching
// field empty. Errors are returned unchanged.
func (s *QuestionService) scanQuestionWithStatsAndAllFieldsFromRows(rows *sql.Rows) (result0 *QuestionWithStats, err error) {
	out := &QuestionWithStats{Question: &models.Question{}}
	var (
		rawContent                                               string
		topic, grammar, vocab, scene, style, difficulty, timeCtx sql.NullString
	)

	err = rows.Scan(
		&out.ID, &out.Type, &out.Language, &out.Level, &out.DifficultyScore,
		&rawContent, &out.CorrectAnswer, &out.Explanation, &out.CreatedAt, &out.Status,
		&topic, &grammar, &vocab, &scene, &style, &difficulty, &timeCtx,
		&out.CorrectCount, &out.IncorrectCount, &out.TotalResponses, &out.UserCount,
	)
	if err != nil {
		return nil, err
	}

	// Copy each optional column only when the database value was not NULL.
	for _, col := range []struct {
		value  sql.NullString
		target *string
	}{
		{topic, &out.TopicCategory},
		{grammar, &out.GrammarFocus},
		{vocab, &out.VocabularyDomain},
		{scene, &out.Scenario},
		{style, &out.StyleModifier},
		{difficulty, &out.DifficultyModifier},
		{timeCtx, &out.TimeContext},
	} {
		if col.value.Valid {
			*col.target = col.value.String
		}
	}

	if decodeErr := out.UnmarshalContentFromJSON(rawContent); decodeErr != nil {
		return nil, decodeErr
	}
	return out, nil
}
// scanQuestionWithPriorityAndStatsFromRows scans a database rows into a QuestionWithStats struct (with priority and stats)
func (s *QuestionService) scanQuestionWithPriorityAndStatsFromRows(rows *sql.Rows) (result0 *QuestionWithStats, err error) {
questionWithStats := &QuestionWithStats{
Question: &models.Question{},
}
var contentJSON string
var priorityScore float64
var timesAnswered int
var lastAnsweredAt sql.NullTime
var confidenceLevel sql.NullInt32
var topicCategory sql.NullString
var grammarFocus sql.NullString
var vocabularyDomain sql.NullString
var scenario sql.NullString
var styleModifier sql.NullString
var difficultyModifier sql.NullString
var timeContext sql.NullString
err = rows.Scan(
&questionWithStats.ID,
&questionWithStats.Type,
&questionWithStats.Language,
&questionWithStats.Level,
&questionWithStats.DifficultyScore,
&contentJSON,
&questionWithStats.CorrectAnswer,
&questionWithStats.Explanation,
&questionWithStats.CreatedAt,
&questionWithStats.Status,
&topicCategory,
&grammarFocus,
&vocabularyDomain,
&scenario,
&styleModifier,
&difficultyModifier,
&timeContext,
&priorityScore,
×Answered,
&lastAnsweredAt,
&questionWithStats.CorrectCount,
&questionWithStats.IncorrectCount,
&questionWithStats.TotalResponses,
&confidenceLevel,
)
if err != nil {
return nil, err
}
// Set optional string fields if they have values
if topicCategory.Valid {
questionWithStats.TopicCategory = topicCategory.String
}
if grammarFocus.Valid {
questionWithStats.GrammarFocus = grammarFocus.String
}
if vocabularyDomain.Valid {
questionWithStats.VocabularyDomain = vocabularyDomain.String
}
if scenario.Valid {
questionWithStats.Scenario = scenario.String
}
if styleModifier.Valid {
questionWithStats.StyleModifier = styleModifier.String
}
if difficultyModifier.Valid {
questionWithStats.DifficultyModifier = difficultyModifier.String
}
if timeContext.Valid {
questionWithStats.TimeContext = timeContext.String
}
if err := questionWithStats.UnmarshalContentFromJSON(contentJSON); err != nil {
return nil, err
}
// Set confidence level if it exists
if confidenceLevel.Valid {
level := int(confidenceLevel.Int32)
questionWithStats.ConfidenceLevel = &level
}
// Populate per-user times answered from the scanned value
questionWithStats.TimesAnswered = timesAnswered
return questionWithStats, nil
}
// scanQuestionWithStatsAndReportersFromRows reads the current row of rows into
// a QuestionWithStats for queries that also select reporter data.
//
// Expected column order: the ten core question columns, the seven nullable
// variety columns, correct / incorrect / total response counts, then the
// nullable reporters and report-reasons strings. Reporters and ReportReasons
// are set only when the column is non-NULL and non-empty. Errors are returned
// unchanged.
func (s *QuestionService) scanQuestionWithStatsAndReportersFromRows(rows *sql.Rows) (result0 *QuestionWithStats, err error) {
	out := &QuestionWithStats{Question: &models.Question{}}
	var (
		rawContent                                               string
		reportedBy, reasons                                      sql.NullString
		topic, grammar, vocab, scene, style, difficulty, timeCtx sql.NullString
	)

	if err = rows.Scan(
		&out.ID, &out.Type, &out.Language, &out.Level, &out.DifficultyScore,
		&rawContent, &out.CorrectAnswer, &out.Explanation, &out.CreatedAt, &out.Status,
		&topic, &grammar, &vocab, &scene, &style, &difficulty, &timeCtx,
		&out.CorrectCount, &out.IncorrectCount, &out.TotalResponses,
		&reportedBy, &reasons,
	); err != nil {
		return nil, err
	}

	// copyIfSet writes src into dst unless the column was NULL.
	copyIfSet := func(dst *string, src sql.NullString) {
		if src.Valid {
			*dst = src.String
		}
	}
	copyIfSet(&out.TopicCategory, topic)
	copyIfSet(&out.GrammarFocus, grammar)
	copyIfSet(&out.VocabularyDomain, vocab)
	copyIfSet(&out.Scenario, scene)
	copyIfSet(&out.StyleModifier, style)
	copyIfSet(&out.DifficultyModifier, difficulty)
	copyIfSet(&out.TimeContext, timeCtx)

	if decodeErr := out.UnmarshalContentFromJSON(rawContent); decodeErr != nil {
		return nil, decodeErr
	}

	// Reporter details are kept only when present and non-empty.
	if reportedBy.Valid && reportedBy.String != "" {
		out.Reporters = reportedBy.String
	}
	if reasons.Valid && reasons.String != "" {
		out.ReportReasons = reasons.String
	}
	return out, nil
}
// getQuestionByQuery runs a single-row query and scans the result with
// scanQuestionFromRow. The query must select the columns in
// questionSelectFields order. When no row matches, the bare sql.ErrNoRows
// sentinel is returned; any other error is returned unchanged.
func (s *QuestionService) getQuestionByQuery(ctx context.Context, query string, args ...interface{}) (result0 *models.Question, err error) {
	found, scanErr := s.scanQuestionFromRow(s.db.QueryRowContext(ctx, query, args...))
	switch {
	case scanErr == nil:
		return found, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		// Propagate sql.ErrNoRows for not found
		return nil, sql.ErrNoRows
	default:
		return nil, scanErr
	}
}
// NewQuestionServiceWithLogger builds a QuestionService from its dependencies.
// db and logger are mandatory and a nil value for either panics; the
// learningService and cfg arguments are stored as given.
func NewQuestionServiceWithLogger(db *sql.DB, learningService *LearningService, cfg *config.Config, logger *observability.Logger) *QuestionService {
	switch {
	case db == nil:
		panic("database connection cannot be nil")
	case logger == nil:
		panic("logger cannot be nil")
	}

	svc := new(QuestionService)
	svc.db = db
	svc.learningService = learningService
	svc.logger = logger
	svc.cfg = cfg
	return svc
}
// getDailyRepeatAvoidDays reports for how many days a question should not be
// repeated in daily assignments. It returns cfg.Server.DailyRepeatAvoidDays
// when a config is present and the value is positive, and 7 otherwise.
func (s *QuestionService) getDailyRepeatAvoidDays() int {
	const fallbackDays = 7

	if s.cfg == nil {
		return fallbackDays
	}
	configured := s.cfg.Server.DailyRepeatAvoidDays
	if configured <= 0 {
		return fallbackDays
	}
	return configured
}
// SaveQuestion saves a question to the database.
//
// The question content is serialized with MarshalContentToJSON and inserted
// together with the other question columns. An empty Status defaults to
// models.QuestionStatusActive (and is written back to the question). On
// success the generated row ID is stored in question.ID.
//
// Returns a wrapped error when serialization or the INSERT fails; the failure
// is also recorded on the tracing span.
func (s *QuestionService) SaveQuestion(ctx context.Context, question *models.Question) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "save_question", observability.AttributeQuestion(question))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	contentJSON, err := question.MarshalContentToJSON()
	if err != nil {
		return contextutils.WrapError(err, "failed to marshal question content")
	}

	if question.Status == "" {
		question.Status = models.QuestionStatusActive
	}

	query := `
INSERT INTO questions (type, language, level, difficulty_score, content, correct_answer, explanation, status, topic_category, grammar_focus, vocabulary_domain, scenario, style_modifier, difficulty_modifier, time_context)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) RETURNING id
`
	var id int
	err = s.db.QueryRowContext(ctx, query,
		question.Type,
		question.Language,
		question.Level,
		question.DifficultyScore,
		contentJSON,
		question.CorrectAnswer,
		question.Explanation,
		question.Status,
		question.TopicCategory,
		question.GrammarFocus,
		question.VocabularyDomain,
		question.Scenario,
		question.StyleModifier,
		question.DifficultyModifier,
		question.TimeContext,
	).Scan(&id)
	if err != nil {
		return contextutils.WrapError(err, "failed to save question to database")
	}

	question.ID = id
	return nil
}
// AssignQuestionToUser links a question to a user by inserting a
// (user_id, question_id) row into user_questions. An existing link is left
// untouched (ON CONFLICT DO NOTHING). The result of the insert is passed
// through contextutils.WrapError and returned.
func (s *QuestionService) AssignQuestionToUser(ctx context.Context, questionID, userID int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "assign_question_to_user", observability.AttributeQuestionID(questionID), observability.AttributeUserID(userID))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const insertAssignment = `
INSERT INTO user_questions (user_id, question_id)
VALUES ($1, $2)
ON CONFLICT (user_id, question_id) DO NOTHING
`
	_, execErr := s.db.ExecContext(ctx, insertAssignment, userID, questionID)
	return contextutils.WrapError(execErr, "failed to assign question to user")
}
// GetQuestionByID loads the question with the given ID, selecting the
// questionSelectFields columns. The lookup is delegated to
// getQuestionByQuery, so a missing row surfaces as sql.ErrNoRows.
func (s *QuestionService) GetQuestionByID(ctx context.Context, id int) (result0 *models.Question, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_question_by_id", observability.AttributeQuestionID(id))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	selectByID := fmt.Sprintf("SELECT %s FROM questions WHERE id = $1", questionSelectFields)
	result0, err = s.getQuestionByQuery(ctx, selectByID, id)
	return result0, err
}
// GetQuestionWithStats retrieves a question by its ID with response statistics.
//
// The result carries the full question (core and variety columns) plus the
// number of correct, incorrect and total responses across user_responses.
// The variety columns are scanned through sql.NullString, as in the other
// scanners in this file, so a NULL leaves the field empty instead of failing
// the scan.
//
// Returns contextutils.ErrQuestionNotFound when no question has the given ID,
// and a wrapped error for any other query or content-decoding failure.
func (s *QuestionService) GetQuestionWithStats(ctx context.Context, id int) (result0 *QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_question_with_stats", observability.AttributeQuestionID(id))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(SUM(CASE WHEN ur.is_correct = true THEN 1 ELSE 0 END), 0) as correct_count,
COALESCE(SUM(CASE WHEN ur.is_correct = false THEN 1 ELSE 0 END), 0) as incorrect_count,
COALESCE(COUNT(ur.id), 0) as total_responses
FROM questions q
LEFT JOIN user_responses ur ON q.id = ur.question_id
WHERE q.id = $1
GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context
`
	q := &models.Question{}
	stats := &QuestionWithStats{Question: q}
	var (
		contentJSON string
		// Nullable variety columns: scanning NULL straight into a string fails.
		topicCategory      sql.NullString
		grammarFocus       sql.NullString
		vocabularyDomain   sql.NullString
		scenario           sql.NullString
		styleModifier      sql.NullString
		difficultyModifier sql.NullString
		timeContext        sql.NullString
	)

	err = s.db.QueryRowContext(ctx, query, id).Scan(
		&q.ID, &q.Type, &q.Language, &q.Level, &q.DifficultyScore,
		&contentJSON, &q.CorrectAnswer, &q.Explanation, &q.CreatedAt, &q.Status,
		&topicCategory, &grammarFocus, &vocabularyDomain, &scenario, &styleModifier, &difficultyModifier, &timeContext,
		&stats.CorrectCount, &stats.IncorrectCount, &stats.TotalResponses,
	)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, contextutils.ErrQuestionNotFound
		}
		return nil, contextutils.WrapError(err, "failed to get question with stats")
	}

	// Set optional string fields if they have values
	if topicCategory.Valid {
		q.TopicCategory = topicCategory.String
	}
	if grammarFocus.Valid {
		q.GrammarFocus = grammarFocus.String
	}
	if vocabularyDomain.Valid {
		q.VocabularyDomain = vocabularyDomain.String
	}
	if scenario.Valid {
		q.Scenario = scenario.String
	}
	if styleModifier.Valid {
		q.StyleModifier = styleModifier.String
	}
	if difficultyModifier.Valid {
		q.DifficultyModifier = difficultyModifier.String
	}
	if timeContext.Valid {
		q.TimeContext = timeContext.String
	}

	// Parse JSON content
	if err = q.UnmarshalContentFromJSON(contentJSON); err != nil {
		return nil, contextutils.WrapError(err, "failed to unmarshal question content")
	}
	return stats, nil
}
// GetQuestionsByFilter retrieves questions matching the specified criteria.
//
// It returns up to limit active questions that are assigned to userID (via
// user_questions) and match language and level, in random order. When
// questionType is empty no type filter is applied; otherwise only questions
// of that type are returned. Only the ten core question columns are loaded.
//
// Returns a wrapped error when the query, a row scan, or row iteration fails.
// The returned slice is nil when nothing matches.
func (s *QuestionService) GetQuestionsByFilter(ctx context.Context, userID int, language, level string, questionType models.QuestionType, limit int) (result0 []models.Question, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_questions_by_filter", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(questionType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	var query string
	var args []interface{}
	if questionType == "" {
		// Don't filter by type if questionType is empty
		query = `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
WHERE uq.user_id = $1 AND q.language = $2 AND q.level = $3 AND q.status = $4
ORDER BY RANDOM()
LIMIT $5
`
		args = []interface{}{userID, language, level, models.QuestionStatusActive, limit}
	} else {
		// Filter by specific type
		query = `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
WHERE uq.user_id = $1 AND q.language = $2 AND q.level = $3 AND q.type = $4 AND q.status = $5
ORDER BY RANDOM()
LIMIT $6
`
		args = []interface{}{userID, language, level, questionType, models.QuestionStatusActive, limit}
	}

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query questions by filter")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []models.Question
	for rows.Next() {
		question, scanErr := s.scanQuestionBasicFromRows(rows)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan question from rows")
		}
		questions = append(questions, *question)
	}
	// rows.Next returns false both at the end of the result set and on a
	// mid-stream failure; only rows.Err tells the two apart.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate over questions by filter")
	}
	return questions, nil
}
// ReportedQuestionWithUser represents a reported question with user information.
// It embeds the question itself and is the element type returned by
// GetReportedQuestions.
type ReportedQuestionWithUser struct {
*models.Question
// ReportedByUsername is filled by GetReportedQuestions from users.username,
// joined through user_questions; it is empty when that column is NULL.
// NOTE(review): the join is on question assignment, so confirm this really
// identifies the reporting user.
ReportedByUsername string `json:"reported_by_username"`
// TotalResponses is the COUNT of user_responses rows selected for the question.
TotalResponses int `json:"total_responses"`
}
// GetReportedQuestions retrieves all questions that have been reported as problematic.
//
// It selects every question whose status is models.QuestionStatusReported,
// newest first, together with a username (joined via user_questions and
// users) and the number of user_responses rows counted for the question. The
// query groups by username, so a question can appear once per joined user.
//
// Returns a wrapped error when the query, a row scan, content decoding, or
// row iteration fails. The returned slice is nil when nothing is reported.
func (s *QuestionService) GetReportedQuestions(ctx context.Context) (result0 []*ReportedQuestionWithUser, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_reported_questions")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status, u.username,
COALESCE(COUNT(ur.id), 0) as total_responses
FROM questions q
LEFT JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN users u ON uq.user_id = u.id
LEFT JOIN user_responses ur ON q.id = ur.question_id
WHERE q.status = $1
GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status, u.username
ORDER BY q.created_at DESC
`
	var rows *sql.Rows
	rows, err = s.db.QueryContext(ctx, query, models.QuestionStatusReported)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query reported questions")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []*ReportedQuestionWithUser
	for rows.Next() {
		var (
			question           models.Question
			reportedByUsername sql.NullString // NULL when the LEFT JOIN finds no user
			contentJSON        string
			totalResponses     int
		)
		err = rows.Scan(
			&question.ID,
			&question.Type,
			&question.Language,
			&question.Level,
			&question.DifficultyScore,
			&contentJSON,
			&question.CorrectAnswer,
			&question.Explanation,
			&question.CreatedAt,
			&question.Status,
			&reportedByUsername,
			&totalResponses,
		)
		if err != nil {
			return nil, contextutils.WrapError(err, "failed to scan reported question")
		}
		if err = question.UnmarshalContentFromJSON(contentJSON); err != nil {
			return nil, contextutils.WrapError(err, "failed to unmarshal question content")
		}

		username := ""
		if reportedByUsername.Valid {
			username = reportedByUsername.String
		}
		questions = append(questions, &ReportedQuestionWithUser{
			Question:           &question,
			ReportedByUsername: username,
			TotalResponses:     totalResponses,
		})
	}
	// rows.Next returns false both at the end of the result set and on a
	// mid-stream failure; only rows.Err tells the two apart.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate over reported questions")
	}
	return questions, nil
}
// MarkQuestionAsFixed puts a question back in rotation by setting its status
// to models.QuestionStatusActive. The update is applied whatever the current
// status is. Returns an error wrapping contextutils.ErrRecordNotFound when no
// row has the given ID, and a wrapped error when the update or the
// affected-row count fails.
func (s *QuestionService) MarkQuestionAsFixed(ctx context.Context, questionID int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "mark_question_as_fixed", observability.AttributeQuestionID(questionID))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	res, execErr := s.db.ExecContext(ctx, `UPDATE questions SET status = $1 WHERE id = $2`, models.QuestionStatusActive, questionID)
	if execErr != nil {
		return contextutils.WrapError(execErr, "failed to mark question as fixed")
	}

	// Zero affected rows means no question has this ID.
	affected, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return contextutils.WrapError(countErr, "failed to get rows affected")
	case affected == 0:
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with ID %d not found", questionID)
	default:
		return nil
	}
}
// UpdateQuestion updates a question's content, correct answer, and explanation.
//
// The content map is serialized through a temporary models.Question so that
// the same MarshalContentToJSON logic used when saving is applied. Other
// columns are left untouched.
//
// Returns an error wrapping contextutils.ErrRecordNotFound when no question
// has the given ID, and a wrapped error when serialization, the UPDATE, or
// the affected-row count fails.
func (s *QuestionService) UpdateQuestion(ctx context.Context, questionID int, content map[string]interface{}, correctAnswerIndex int, explanation string) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "update_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Marshal provided content map via a temporary Question instance to reuse method
	tempQ := &models.Question{Content: content}
	contentJSON, err := tempQ.MarshalContentToJSON()
	if err != nil {
		return contextutils.WrapError(err, "failed to marshal content JSON")
	}

	query := `UPDATE questions SET content = $1, correct_answer = $2, explanation = $3 WHERE id = $4`
	var result sql.Result
	result, err = s.db.ExecContext(ctx, query, contentJSON, correctAnswerIndex, explanation, questionID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update question")
	}

	// Check if the question was actually updated
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if rowsAffected == 0 {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with ID %d not found", questionID)
	}
	return nil
}
// DeleteQuestion permanently deletes a question from the database.
//
// The user responses that reference the question are deleted first, then the
// question row itself. Both statements run in one transaction so a failure
// (including "question not found") leaves the responses untouched instead of
// deleting them while the question remains.
//
// Returns an error wrapping contextutils.ErrRecordNotFound when no question
// has the given ID, and a wrapped error when any statement, the affected-row
// count, or the commit fails.
func (s *QuestionService) DeleteQuestion(ctx context.Context, questionID int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "delete_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	defer func() {
		if err == nil {
			return
		}
		// Roll back on any failure. sql.ErrTxDone only means the transaction
		// already ended (e.g. a failed commit), which is not worth a warning.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// First, delete associated user responses
	deleteResponsesQuery := `DELETE FROM user_responses WHERE question_id = $1`
	if _, err = tx.ExecContext(ctx, deleteResponsesQuery, questionID); err != nil {
		return contextutils.WrapError(err, "failed to delete associated user responses")
	}

	// Then delete the question itself
	deleteQuestionQuery := `DELETE FROM questions WHERE id = $1`
	var result sql.Result
	result, err = tx.ExecContext(ctx, deleteQuestionQuery, questionID)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete question")
	}

	// Check if the question was actually deleted
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if rowsAffected == 0 {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with ID %d not found", questionID)
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	return nil
}
// ReportQuestion flags a question as reported on behalf of userID.
//
// Inside one transaction it verifies the question exists, sets its status to
// models.QuestionStatusReported and upserts a question_reports row for the
// (question, user) pair. An empty reportReason is replaced by a generic message.
// The returned error wraps contextutils.ErrRecordNotFound when the question does
// not exist; whenever an error is returned a rollback is attempted.
func (s *QuestionService) ReportQuestion(ctx context.Context, questionID, userID int, reportReason string) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "report_question", observability.AttributeQuestionID(questionID), observability.AttributeUserID(userID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Fall back to a generic reason when the caller supplied none.
	reason := reportReason
	if reason == "" {
		reason = "Question reported by user"
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	defer func() {
		if err == nil {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			s.logger.Warn(ctx, "Failed to rollback transaction", map[string]interface{}{"error": rbErr.Error()})
		}
	}()

	// Reject unknown questions before touching anything.
	var exists bool
	if err = tx.QueryRowContext(ctx, `SELECT EXISTS(SELECT 1 FROM questions WHERE id = $1)`, questionID).Scan(&exists); err != nil {
		return contextutils.WrapError(err, "failed to check if question exists")
	}
	if !exists {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with id %d not found", questionID)
	}

	// Flip the question into the reported state.
	const markReported = `UPDATE questions SET status = $1 WHERE id = $2`
	res, err := tx.ExecContext(ctx, markReported, models.QuestionStatusReported, questionID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update question status")
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if affected == 0 {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with ID %d not found", questionID)
	}

	// Upsert the report: a repeat report by the same user overwrites the stored
	// reason and refreshes created_at so the latest report time is what is kept.
	const upsertReport = `INSERT INTO question_reports (question_id, reported_by_user_id, report_reason) VALUES ($1, $2, $3) ON CONFLICT (question_id, reported_by_user_id) DO UPDATE SET report_reason = EXCLUDED.report_reason, created_at = now()`
	if _, err = tx.ExecContext(ctx, upsertReport, questionID, userID, reason); err != nil {
		return contextutils.WrapError(err, "failed to create question report")
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	return nil
}
// GetNextQuestion returns the next question to show a user for the given
// language, level and question type. It is a traced entry point that delegates
// the actual choice to getNextQuestionWithPriority and returns that call's
// result and error unchanged.
func (s *QuestionService) GetNextQuestion(ctx context.Context, userID int, language, level string, qType models.QuestionType) (result0 *QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(
		ctx,
		"get_next_question",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		observability.AttributeQuestionType(qType),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Priority-based selection; the returned value already carries its stats.
	result0, err = s.getNextQuestionWithPriority(ctx, userID, language, level, qType)
	return result0, err
}
// getNextQuestionWithPriority picks one question for the user.
//
// It loads the user's learning preferences (falling back to the service
// defaults when that lookup fails), fetches the user's available questions for
// the language/level/type, and chooses one via selectQuestionWithFreshnessRatio.
// When the user has no available questions it falls back to
// GetRandomGlobalQuestionForUser; a nil result with a nil error means nothing is
// available at all.
func (s *QuestionService) getNextQuestionWithPriority(ctx context.Context, userID int, language, level string, qType models.QuestionType) (result0 *QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_next_question_with_priority", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(qType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// A preferences failure is not fatal: log it and continue with defaults.
	prefs, err := s.learningService.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		s.logger.Warn(ctx, "Failed to get user preferences", map[string]interface{}{"user_id": userID, "error": err.Error()})
		prefs = s.learningService.GetDefaultLearningPreferences()
	}

	candidates, err := s.getAvailableQuestionsWithPriority(ctx, userID, language, level, qType, prefs)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get available questions")
	}

	// Nothing personalised is left: try a random global question instead. A nil
	// fallback is passed through as "no questions available at all".
	if len(candidates) == 0 {
		fallback, fallbackErr := s.GetRandomGlobalQuestionForUser(ctx, userID, language, level, qType)
		if fallbackErr != nil {
			return nil, contextutils.WrapError(fallbackErr, "no personalized questions, and failed to get global fallback question")
		}
		return fallback, nil
	}

	// Choose among the candidates according to the user's fresh-question ratio.
	picked, err := s.selectQuestionWithFreshnessRatio(candidates, prefs.FreshQuestionRatio)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to select question with freshness ratio")
	}
	return picked, nil
}
// GetAdaptiveQuestionsForDaily selects up to limit distinct questions for a
// user's daily assignment.
//
// The limit is split as evenly as possible across the four question types
// (vocabulary, fill-in-blank, question/answer, reading comprehension). Within a
// type, questions are drawn with selectQuestionWithFreshnessRatio using the
// user's FreshQuestionRatio (0.7 when preferences cannot be loaded). A type with
// no available questions contributes at most one global fallback question.
// Slots still empty afterwards are filled from the unselected questions of any
// type. The result is de-duplicated and interleaved by type so that truncating
// it does not systematically drop the types processed last.
//
// A non-positive limit yields an empty, non-nil slice. Per-type lookup failures
// are logged and skipped; the function itself always returns a nil error.
func (s *QuestionService) GetAdaptiveQuestionsForDaily(ctx context.Context, userID int, language, level string, limit int) (result0 []*QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_adaptive_questions_for_daily")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Nothing can be selected for a non-positive limit. Returning here also
	// keeps a negative limit from being used as a slice bound further down.
	if limit <= 0 {
		return []*QuestionWithStats{}, nil
	}

	// Get user learning preferences
	prefs, err := s.learningService.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		s.logger.Warn(ctx, "Failed to get user learning preferences, using defaults", map[string]interface{}{
			"user_id": userID, "error": err.Error(),
		})
		prefs = &models.UserLearningPreferences{
			FreshQuestionRatio: 0.7,
		}
	}

	// withoutID returns a new slice holding every entry of pool whose ID differs
	// from id.
	withoutID := func(pool []*QuestionWithStats, id int) []*QuestionWithStats {
		var kept []*QuestionWithStats
		for _, q := range pool {
			if q.ID != id {
				kept = append(kept, q)
			}
		}
		return kept
	}

	var selectedQuestions []*QuestionWithStats
	selectedQuestionIDs := make(map[int]bool) // Track selected question IDs to prevent duplicates

	// Select questions across different types to provide variety
	questionTypes := []models.QuestionType{models.Vocabulary, models.FillInBlank, models.QuestionAnswer, models.ReadingComprehension}

	// Calculate how many questions to select from each type
	questionsPerType := limit / len(questionTypes)
	remainingQuestions := limit % len(questionTypes)

	for i, qType := range questionTypes {
		// The first `remainingQuestions` types each take one extra slot.
		currentLimit := questionsPerType
		if i < remainingQuestions {
			currentLimit++
		}
		if currentLimit == 0 {
			continue
		}

		// Get available questions for DAILY with 2-day recent-correct exclusion
		questions, fetchErr := s.getAvailableQuestionsForDailyWithPriority(ctx, userID, language, level, qType, prefs)
		if fetchErr != nil {
			s.logger.Warn(ctx, "Failed to get questions for type", map[string]interface{}{
				"user_id": userID, "type": qType, "error": fetchErr.Error(),
			})
			continue
		}

		// Filter out questions that have already been selected
		var availableQuestions []*QuestionWithStats
		for _, q := range questions {
			if !selectedQuestionIDs[q.ID] {
				availableQuestions = append(availableQuestions, q)
			}
		}

		if len(availableQuestions) == 0 {
			// Try to get a global fallback question for this type
			globalQ, globalErr := s.GetRandomGlobalQuestionForUser(ctx, userID, language, level, qType)
			if globalErr != nil {
				s.logger.Warn(ctx, "Failed to get global fallback question", map[string]interface{}{
					"user_id": userID, "type": qType, "error": globalErr.Error(),
				})
				continue
			}
			if globalQ != nil && !selectedQuestionIDs[globalQ.ID] {
				selectedQuestions = append(selectedQuestions, globalQ)
				selectedQuestionIDs[globalQ.ID] = true
				s.logger.Info(ctx, "Added global fallback question", map[string]interface{}{
					"user_id": userID, "type": qType, "question_id": globalQ.ID,
				})
			}
			continue
		}

		// Select questions for this type using adaptive selection
		s.logger.Info(ctx, "Starting selection for question type", map[string]interface{}{
			"user_id": userID, "type": qType, "current_limit": currentLimit, "available_questions": len(availableQuestions),
		})

		questionsSelected := 0
		remainingQuestionsForType := availableQuestions
		for j := 0; j < currentLimit && len(remainingQuestionsForType) > 0; j++ {
			// Apply freshness ratio logic for each selection
			selectedQuestion, selectErr := s.selectQuestionWithFreshnessRatio(remainingQuestionsForType, prefs.FreshQuestionRatio)
			if selectErr != nil {
				s.logger.Warn(ctx, "Failed to select question with freshness ratio", map[string]interface{}{
					"user_id": userID, "type": qType, "error": selectErr.Error(),
				})
				// Fallback to simple random selection; the loop condition
				// guarantees the pool is non-empty here.
				selectedQuestion = remainingQuestionsForType[rand.Intn(len(remainingQuestionsForType))]
			}

			if selectedQuestion != nil && !selectedQuestionIDs[selectedQuestion.ID] {
				selectedQuestions = append(selectedQuestions, selectedQuestion)
				selectedQuestionIDs[selectedQuestion.ID] = true
				questionsSelected++
				// Remove the selected question from the remaining pool
				remainingQuestionsForType = withoutID(remainingQuestionsForType, selectedQuestion.ID)
				s.logger.Info(ctx, "Successfully selected question", map[string]interface{}{
					"user_id": userID, "type": qType, "iteration": j, "question_id": selectedQuestion.ID,
					"total_selected": len(selectedQuestions),
				})
				continue
			}

			s.logger.Warn(ctx, "Failed to select question for type", map[string]interface{}{
				"user_id": userID, "type": qType, "iteration": j, "current_limit": currentLimit,
				"selected_question_nil": selectedQuestion == nil,
				"already_selected":      selectedQuestion != nil && selectedQuestionIDs[selectedQuestion.ID],
			})
			// Remove the question from the pool even if it was already selected
			if selectedQuestion != nil {
				remainingQuestionsForType = withoutID(remainingQuestionsForType, selectedQuestion.ID)
			}
		}

		// If we didn't select enough questions for this type, try simple selection from all available questions
		if questionsSelected < currentLimit {
			s.logger.Info(ctx, "Using simple selection to fill remaining slots", map[string]interface{}{
				"user_id": userID, "type": qType, "questions_selected": questionsSelected, "current_limit": currentLimit,
			})
			// Get all questions for this type again and filter out already selected ones
			allQuestionsForType, refetchErr := s.getAvailableQuestionsForDailyWithPriority(ctx, userID, language, level, qType, prefs)
			if refetchErr == nil {
				for _, q := range allQuestionsForType {
					if !selectedQuestionIDs[q.ID] && questionsSelected < currentLimit {
						selectedQuestions = append(selectedQuestions, q)
						selectedQuestionIDs[q.ID] = true
						questionsSelected++
					}
				}
			}
		}

		s.logger.Info(ctx, "Completed selection for question type", map[string]interface{}{
			"user_id": userID, "type": qType, "questions_selected": questionsSelected, "target": currentLimit,
		})
	}

	// If we don't have enough questions, fill with random questions from any type
	if len(selectedQuestions) < limit {
		remainingNeeded := limit - len(selectedQuestions)
		s.logger.Info(ctx, "Not enough questions from type-based selection, using fallback", map[string]interface{}{
			"user_id": userID, "selected_count": len(selectedQuestions), "limit": limit, "remaining_needed": remainingNeeded,
		})

		// Get all available questions by trying each question type
		var allQuestions []*QuestionWithStats
		questionIDMap := make(map[int]bool) // Track seen question IDs to avoid duplicates
		for _, qType := range questionTypes {
			questions, fetchErr := s.getAvailableQuestionsForDailyWithPriority(ctx, userID, language, level, qType, prefs)
			if fetchErr != nil {
				continue
			}
			for _, q := range questions {
				if !questionIDMap[q.ID] && !selectedQuestionIDs[q.ID] {
					allQuestions = append(allQuestions, q)
					questionIDMap[q.ID] = true
				}
			}
		}

		s.logger.Info(ctx, "Fallback questions available", map[string]interface{}{
			"user_id": userID, "all_questions_count": len(allQuestions),
		})

		// Keep drawing until the limit is reached or the pool is exhausted. The
		// pool shrinks as questions are drawn, so the loop must not compare a
		// growing index against its length (that stops after roughly half the
		// achievable picks). `attempts` caps the iterations at the initial pool
		// size so the loop terminates even if a draw does not shrink the pool.
		for attempts := len(allQuestions); attempts > 0 && len(selectedQuestions) < limit && len(allQuestions) > 0; attempts-- {
			selectedQuestion, selectErr := s.selectQuestionWithFreshnessRatio(allQuestions, prefs.FreshQuestionRatio)
			if selectErr != nil {
				s.logger.Warn(ctx, "Failed to select question with freshness ratio in fallback", map[string]interface{}{
					"user_id": userID, "error": selectErr.Error(),
				})
				// Fallback to simple random selection
				selectedQuestion = allQuestions[rand.Intn(len(allQuestions))]
			}
			if selectedQuestion == nil {
				continue
			}
			if !selectedQuestionIDs[selectedQuestion.ID] {
				selectedQuestions = append(selectedQuestions, selectedQuestion)
				selectedQuestionIDs[selectedQuestion.ID] = true
			}
			// Drop the drawn question from the pool whether or not it was new.
			allQuestions = withoutID(allQuestions, selectedQuestion.ID)
		}
	}

	// Ensure we don't exceed the limit
	if len(selectedQuestions) > limit {
		selectedQuestions = selectedQuestions[:limit]
	}

	// Final duplicate check - this should never happen but provides extra safety
	finalSelectedQuestions := make([]*QuestionWithStats, 0, len(selectedQuestions))
	finalSelectedIDs := make(map[int]bool)
	for _, q := range selectedQuestions {
		if finalSelectedIDs[q.ID] {
			s.logger.Warn(ctx, "Duplicate question detected in final selection", map[string]interface{}{
				"user_id": userID, "question_id": q.ID,
			})
			continue
		}
		finalSelectedQuestions = append(finalSelectedQuestions, q)
		finalSelectedIDs[q.ID] = true
	}

	// Interleave selected questions by type to avoid bias toward types that were
	// selected earlier in the algorithm. This ensures that when callers slice the
	// returned list (e.g., to meet a smaller goal), later types like
	// ReadingComprehension are not systematically excluded.
	typeBuckets := make(map[models.QuestionType][]*QuestionWithStats)
	var typeOrder []models.QuestionType
	for _, q := range finalSelectedQuestions {
		if _, ok := typeBuckets[q.Type]; !ok {
			typeOrder = append(typeOrder, q.Type)
		}
		typeBuckets[q.Type] = append(typeBuckets[q.Type], q)
	}

	// Round-robin over the types in first-seen order until every bucket is empty.
	interleaved := make([]*QuestionWithStats, 0, len(finalSelectedQuestions))
	for len(interleaved) < len(finalSelectedQuestions) {
		added := false
		for _, t := range typeOrder {
			if len(typeBuckets[t]) == 0 {
				continue
			}
			interleaved = append(interleaved, typeBuckets[t][0])
			typeBuckets[t] = typeBuckets[t][1:]
			added = true
			if len(interleaved) >= len(finalSelectedQuestions) {
				break
			}
		}
		if !added {
			break
		}
	}
	finalSelectedQuestions = interleaved

	s.logger.Info(ctx, "Selected adaptive questions for daily assignment", map[string]interface{}{
		"user_id":            userID,
		"language":           language,
		"level":              level,
		"requested_limit":    limit,
		"selected_count":     len(finalSelectedQuestions),
		"duplicates_removed": len(selectedQuestions) - len(finalSelectedQuestions),
	})
	return finalSelectedQuestions, nil
}
// GetQuestionStats returns basic statistics about the questions table.
//
// The returned map has three keys:
//   - "total_questions":    int, the number of rows in questions
//   - "questions_by_type":  map[string]int, row count per type
//   - "questions_by_level": map[string]int, row count per level
//
// Any query, scan or row-iteration failure is returned wrapped and recorded on
// the span; in that case the map is nil.
func (s *QuestionService) GetQuestionStats(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_question_stats")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// countBy runs a (label, count) aggregate query and returns it as a map. Each
	// call owns and closes its own result set, and reports iteration errors that
	// rows.Next would otherwise hide behind a short loop.
	countBy := func(query, dimension string) (map[string]int, error) {
		rows, queryErr := s.db.QueryContext(ctx, query)
		if queryErr != nil {
			return nil, contextutils.WrapError(queryErr, "failed to query questions by "+dimension)
		}
		defer func() {
			if closeErr := rows.Close(); closeErr != nil {
				s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
			}
		}()

		counts := make(map[string]int)
		for rows.Next() {
			var label string
			var count int
			if scanErr := rows.Scan(&label, &count); scanErr != nil {
				return nil, contextutils.WrapError(scanErr, "failed to scan question "+dimension+" count")
			}
			counts[label] = count
		}
		if iterErr := rows.Err(); iterErr != nil {
			return nil, contextutils.WrapError(iterErr, "failed to iterate question "+dimension+" counts")
		}
		return counts, nil
	}

	stats := make(map[string]interface{})

	// Total questions
	var totalQuestions int
	err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM questions").Scan(&totalQuestions)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get total questions count")
	}
	stats["total_questions"] = totalQuestions

	// Questions by type
	typeQuery := `
SELECT type, COUNT(*) as count
FROM questions
GROUP BY type
`
	questionsByType, err := countBy(typeQuery, "type")
	if err != nil {
		return nil, err
	}
	stats["questions_by_type"] = questionsByType

	// Questions by level
	levelQuery := `
SELECT level, COUNT(*) as count
FROM questions
GROUP BY level
`
	questionsByLevel, err := countBy(levelQuery, "level")
	if err != nil {
		return nil, err
	}
	stats["questions_by_level"] = questionsByLevel

	return stats, nil
}
// GetDetailedQuestionStats returns detailed statistics about the questions
// table.
//
// The returned map has five keys:
//   - "total_questions":       int, the number of rows in questions
//   - "questions_by_detail":   map[language]map[level]map[type]int
//   - "questions_by_language": map[string]int
//   - "questions_by_type":     map[string]int
//   - "questions_by_level":    map[string]int
//
// Any query, scan or row-iteration failure is returned wrapped and recorded on
// the span; in that case the map is nil.
func (s *QuestionService) GetDetailedQuestionStats(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_detailed_question_stats")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// closeRows closes a result set, logging rather than returning a failure.
	closeRows := func(closer interface{ Close() error }) {
		if closeErr := closer.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}

	// countBy runs a (label, count) aggregate query and returns it as a map. Each
	// call owns and closes its own result set and surfaces iteration errors.
	countBy := func(query, dimension string) (map[string]int, error) {
		rows, queryErr := s.db.QueryContext(ctx, query)
		if queryErr != nil {
			return nil, contextutils.WrapError(queryErr, "failed to query questions by "+dimension)
		}
		defer closeRows(rows)

		counts := make(map[string]int)
		for rows.Next() {
			var label string
			var count int
			if scanErr := rows.Scan(&label, &count); scanErr != nil {
				return nil, contextutils.WrapError(scanErr, "failed to scan question "+dimension+" count")
			}
			counts[label] = count
		}
		if iterErr := rows.Err(); iterErr != nil {
			return nil, contextutils.WrapError(iterErr, "failed to iterate question "+dimension+" counts")
		}
		return counts, nil
	}

	// countByDetail builds the nested language -> level -> type -> count map.
	countByDetail := func(query string) (map[string]map[string]map[string]int, error) {
		rows, queryErr := s.db.QueryContext(ctx, query)
		if queryErr != nil {
			return nil, contextutils.WrapError(queryErr, "failed to query questions by detail")
		}
		defer closeRows(rows)

		nested := make(map[string]map[string]map[string]int)
		for rows.Next() {
			var language, level, qType string
			var count int
			if scanErr := rows.Scan(&language, &level, &qType, &count); scanErr != nil {
				return nil, contextutils.WrapError(scanErr, "failed to scan question detail count")
			}
			// Create the intermediate maps on first use; writing to a nil map panics.
			if nested[language] == nil {
				nested[language] = make(map[string]map[string]int)
			}
			if nested[language][level] == nil {
				nested[language][level] = make(map[string]int)
			}
			nested[language][level][qType] = count
		}
		if iterErr := rows.Err(); iterErr != nil {
			return nil, contextutils.WrapError(iterErr, "failed to iterate question detail counts")
		}
		return nested, nil
	}

	stats := make(map[string]interface{})

	// Total questions
	var totalQuestions int
	err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM questions").Scan(&totalQuestions)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get total questions count")
	}
	stats["total_questions"] = totalQuestions

	// Questions by language, level, and type combination
	detailQuery := `
SELECT language, level, type, COUNT(*) as count
FROM questions
GROUP BY language, level, type
ORDER BY language, level, type
`
	questionsByDetail, err := countByDetail(detailQuery)
	if err != nil {
		return nil, err
	}
	stats["questions_by_detail"] = questionsByDetail

	// Questions by language
	languageQuery := `
SELECT language, COUNT(*) as count
FROM questions
GROUP BY language
`
	questionsByLanguage, err := countBy(languageQuery, "language")
	if err != nil {
		return nil, err
	}
	stats["questions_by_language"] = questionsByLanguage

	// Questions by type
	typeQuery := `
SELECT type, COUNT(*) as count
FROM questions
GROUP BY type
`
	questionsByType, err := countBy(typeQuery, "type")
	if err != nil {
		return nil, err
	}
	stats["questions_by_type"] = questionsByType

	// Questions by level
	levelQuery := `
SELECT level, COUNT(*) as count
FROM questions
GROUP BY level
`
	questionsByLevel, err := countBy(levelQuery, "level")
	if err != nil {
		return nil, err
	}
	stats["questions_by_level"] = questionsByLevel

	return stats, nil
}
// GetRecentQuestionContentsForUser returns up to limit distinct question
// contents for questions that are assigned to the user (user_questions) and that
// the user has responded to (user_responses).
//
// The query orders rows by q.content descending, not by response time.
// NOTE(review): the name suggests recency ordering — confirm whether ordering by
// content is intended.
//
// The returned slice is never nil: it is empty when there are no rows and also
// when an error is returned. Query, scan and row-iteration errors are returned
// unwrapped and recorded on the span.
func (s *QuestionService) GetRecentQuestionContentsForUser(ctx context.Context, userID, limit int) (result0 []string, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_recent_question_contents_for_user", observability.AttributeUserID(userID), observability.AttributeLimit(limit))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT DISTINCT q.content
FROM user_responses ur
JOIN questions q ON ur.question_id = q.id
JOIN user_questions uq ON q.id = uq.question_id
WHERE ur.user_id = $1 AND uq.user_id = $2
ORDER BY q.content DESC
LIMIT $3
`
	rows, err := s.db.QueryContext(ctx, query, userID, userID, limit)
	if err != nil {
		return []string{}, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// Start from an empty (non-nil) slice so callers always get a slice value.
	contents := []string{}
	for rows.Next() {
		var content string
		if scanErr := rows.Scan(&content); scanErr != nil {
			return []string{}, scanErr
		}
		contents = append(contents, content)
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check a truncated result would be reported as success.
	if err = rows.Err(); err != nil {
		return []string{}, err
	}
	return contents, nil
}
// GetUserQuestions returns up to limit questions assigned to the user through
// user_questions, newest first (ordered by questions.created_at descending).
//
// Each row is decoded with scanQuestionFromRows. Query, scan and row-iteration
// errors are returned unwrapped and recorded on the span; the slice is nil when
// an error is returned or when there are no rows.
func (s *QuestionService) GetUserQuestions(ctx context.Context, userID, limit int) (result0 []*models.Question, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_user_questions", observability.AttributeUserID(userID), observability.AttributeLimit(limit))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status, q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
WHERE uq.user_id = $1
ORDER BY q.created_at DESC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []*models.Question
	for rows.Next() {
		question, scanErr := s.scanQuestionFromRows(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		questions = append(questions, question)
	}
	// rows.Next also returns false when iteration fails part-way; without this
	// check a truncated result would be reported as success.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return questions, nil
}
// GetUserQuestionsWithStats returns up to limit questions assigned to the user,
// newest first, each with aggregated response counters.
//
// The counters come from the query itself: correct/incorrect/total counts over
// the user_responses rows joined on the question, and user_count as the number
// of user_questions rows for the question. Rows are decoded with
// scanQuestionWithStatsFromRows. Errors are returned unwrapped and recorded on
// the span.
func (s *QuestionService) GetUserQuestionsWithStats(ctx context.Context, userID, limit int) (result0 []*QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_user_questions_with_stats", observability.AttributeUserID(userID), observability.AttributeLimit(limit))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const statsQuery = `
SELECT
q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
COALESCE(SUM(CASE WHEN ur.is_correct = true THEN 1 ELSE 0 END), 0) as correct_count,
COALESCE(SUM(CASE WHEN ur.is_correct = false THEN 1 ELSE 0 END), 0) as incorrect_count,
COALESCE(COUNT(ur.id), 0) as total_responses,
COALESCE(uq_stats.user_count, 0) as user_count
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN user_responses ur ON q.id = ur.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(*) as user_count
FROM user_questions
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
WHERE uq.user_id = $1
GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
uq_stats.user_count
ORDER BY q.created_at DESC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, statsQuery, userID, limit)
	if err != nil {
		return nil, err
	}
	defer func() {
		// A close failure is only logged; it never changes the result.
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
	}()

	var collected []*QuestionWithStats
	for rows.Next() {
		entry, scanErr := s.scanQuestionWithStatsFromRows(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		collected = append(collected, entry)
	}
	// Surface an iteration failure instead of returning a truncated list.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return collected, nil
}
// QuestionWithStats represents a question with response statistics.
//
// It embeds *models.Question, so the question's own fields (for example ID and
// Type) are reachable directly on a QuestionWithStats value. A nil embedded
// pointer makes those promoted fields panic on access.
type QuestionWithStats struct {
*models.Question
// CorrectCount, IncorrectCount and TotalResponses are response counters.
// NOTE(review): presumably filled from the correct_count / incorrect_count /
// total_responses columns of the stats queries in this file — confirm against
// the scan helpers.
CorrectCount int `json:"correct_count"`
IncorrectCount int `json:"incorrect_count"`
TotalResponses int `json:"total_responses"`
// TimesAnswered tracks how many times THIS user answered the question (per-user)
TimesAnswered int `json:"times_answered"`
// UserCount is presumably the number of user_questions rows for the question
// (the user_count column of the stats queries) — confirm against the scan helpers.
UserCount int `json:"user_count"`
// Reporters and ReportReasons are omitted from JSON when empty.
// NOTE(review): their population is not visible here; presumably aggregated
// from question_reports — confirm.
Reporters string `json:"reporters,omitempty"`
ReportReasons string `json:"report_reasons,omitempty"`
// ConfidenceLevel is optional: a nil pointer is omitted from JSON.
ConfidenceLevel *int `json:"confidence_level,omitempty"`
}
// GetQuestionsPaginated returns one page of the user's questions with response
// statistics, together with the total number of rows matching the filters.
//
// Parameters:
//   - page is 1-based; values below 1 are treated as the first page.
//   - pageSize is sent to the database as the LIMIT.
//   - search, when non-empty, is matched case-insensitively (ILIKE, wrapped in
//     %...%) against the question content and explanation.
//   - typeFilter and statusFilter, when non-empty, require an exact match on
//     q.type and q.status.
//
// All filter values are bound as query parameters; only the placeholder numbers
// are formatted into the SQL text. Rows are ordered by q.id descending and
// decoded with scanQuestionWithStatsAndAllFieldsFromRows. Errors are returned
// unwrapped and recorded on the span.
func (s *QuestionService) GetQuestionsPaginated(ctx context.Context, userID, page, pageSize int, search, typeFilter, statusFilter string) (result0 []*QuestionWithStats, result1 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_questions_paginated", observability.AttributeUserID(userID), observability.AttributePage(page), observability.AttributePageSize(pageSize), observability.AttributeSearch(search), observability.AttributeTypeFilter(typeFilter), observability.AttributeStatusFilter(statusFilter))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Build WHERE clause with filters using parameterized queries
	whereConditions := []string{"uq.user_id = $1"}
	args := []interface{}{userID}
	argCount := 1

	// Add search filter (the same placeholder is used for both columns)
	if search != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("(q.content::text ILIKE $%d OR q.explanation ILIKE $%d)", argCount, argCount))
		args = append(args, "%"+search+"%")
	}

	// Add type filter
	if typeFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.type = $%d", argCount))
		args = append(args, typeFilter)
	}

	// Add status filter
	if statusFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.status = $%d", argCount))
		args = append(args, statusFilter)
	}

	// Join all conditions
	whereClause := "WHERE " + strings.Join(whereConditions, " AND ")

	// First get the total count with filters
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM questions q JOIN user_questions uq ON q.id = uq.question_id %s", whereClause)
	var totalCount int
	err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount)
	if err != nil {
		return nil, 0, err
	}

	// Pages are 1-based. Anything lower is served as the first page so the
	// OFFSET sent to the database is never negative.
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize

	// Build main query with pagination
	query := fmt.Sprintf(`
SELECT
q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(SUM(CASE WHEN ur.is_correct = true THEN 1 ELSE 0 END), 0) as correct_count,
COALESCE(SUM(CASE WHEN ur.is_correct = false THEN 1 ELSE 0 END), 0) as incorrect_count,
COALESCE(COUNT(ur.id), 0) as total_responses,
COALESCE(uq_stats.user_count, 0) as user_count
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN user_responses ur ON q.id = ur.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(*) as user_count
FROM user_questions
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
%s
GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
uq_stats.user_count
ORDER BY q.id DESC
LIMIT $%d OFFSET $%d
`, whereClause, argCount+1, argCount+2)

	// Add pagination parameters
	args = append(args, pageSize, offset)

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []*QuestionWithStats
	for rows.Next() {
		questionWithStats, scanErr := s.scanQuestionWithStatsAndAllFieldsFromRows(rows)
		if scanErr != nil {
			return nil, 0, scanErr
		}
		questions = append(questions, questionWithStats)
	}
	// Surface an iteration failure instead of returning a truncated page.
	if err = rows.Err(); err != nil {
		return nil, 0, err
	}
	return questions, totalCount, nil
}
// PRIORITY-BASED QUESTION SELECTION METHODS
// getAvailableQuestionsWithPriority loads up to 50 candidate questions for a user,
// together with their priority score and response statistics.
//
// A question qualifies when it is assigned to the user (user_questions), matches
// the requested language, level and type, has status 'active', and additionally:
//   - the user has not answered it within the last hour,
//   - the user's three most recent responses to it in the last 90 days were not
//     all correct,
//   - the user has not marked it as known with confidence 5 in the last 60 days.
//
// Results are ordered by priority score (default 100.0 when no score row exists),
// highest first, with ties broken randomly. The learning-preferences argument is
// accepted but ignored.
//
// Rows that fail to scan are logged and skipped. A failure of the query itself,
// or an error reported by the driver while iterating, is returned wrapped in
// ErrDatabaseQuery.
func (s *QuestionService) getAvailableQuestionsWithPriority(ctx context.Context, userID int, language, level string, qType models.QuestionType, _ *models.UserLearningPreferences) (result0 []*QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_available_questions_with_priority", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(qType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Build SQL query with priority scoring and stats
	query := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(qps.priority_score, 100.0) as priority_score,
COALESCE(uq_stats.times_answered, 0) as times_answered,
uq_stats.last_answered_at,
COALESCE(stats.correct_count, 0) as correct_count,
COALESCE(stats.incorrect_count, 0) as incorrect_count,
COALESCE(stats.total_responses, 0) as total_responses,
uqm.confidence_level
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
LEFT JOIN (
SELECT question_id,
COUNT(*) as times_answered,
MAX(created_at) as last_answered_at
FROM user_responses
WHERE user_id = $1
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(CASE WHEN is_correct = true THEN 1 END) as correct_count,
COUNT(CASE WHEN is_correct = false THEN 1 END) as incorrect_count,
COUNT(*) as total_responses
FROM user_responses
GROUP BY question_id
) stats ON q.id = stats.question_id
LEFT JOIN user_question_metadata uqm ON q.id = uqm.question_id AND uqm.user_id = $1
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
AND q.status = 'active'
AND q.id NOT IN (
SELECT ur.question_id
FROM user_responses ur
WHERE ur.user_id = $1
AND ur.created_at > NOW() - INTERVAL '1 hour'
)
-- Exclude questions where the user's last 3 responses were all correct within the last 90 days
AND NOT EXISTS (
SELECT 1 FROM (
SELECT ur2.is_correct
FROM user_responses ur2
WHERE ur2.user_id = $1
AND ur2.question_id = q.id
AND ur2.created_at >= NOW() - INTERVAL '90 days'
ORDER BY ur2.created_at DESC
LIMIT 3
) recent_three
WHERE (SELECT COUNT(*) FROM (
SELECT 1 FROM (
SELECT ur3.is_correct
FROM user_responses ur3
WHERE ur3.user_id = $1
AND ur3.question_id = q.id
AND ur3.created_at >= NOW() - INTERVAL '90 days'
ORDER BY ur3.created_at DESC
LIMIT 3
) t WHERE t.is_correct = TRUE
) c) = 3
)
-- Exclude questions the user explicitly marked as known with max confidence (5)
-- within the last 60 days (approx. 2 months)
AND NOT EXISTS (
SELECT 1 FROM user_question_metadata uqm2
WHERE uqm2.user_id = $1
AND uqm2.question_id = q.id
AND uqm2.marked_as_known = TRUE
AND uqm2.confidence_level = 5
AND uqm2.marked_as_known_at >= NOW() - INTERVAL '60 days'
)
ORDER BY priority_score DESC, RANDOM()
LIMIT 50
`
	rows, err := s.db.QueryContext(ctx, query, userID, language, level, qType)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to query questions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []*QuestionWithStats
	for rows.Next() {
		scanned, scanErr := s.scanQuestionWithPriorityAndStatsFromRows(rows)
		if scanErr != nil {
			s.logger.Error(ctx, "Error scanning question", scanErr, map[string]interface{}{})
			continue // Skip malformed rows
		}
		questions = append(questions, scanned)
	}
	// rows.Next returning false may mean the iteration was cut short (cancelled
	// context, dropped connection); surface that instead of a truncated list.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error iterating questions: %w", err)
	}
	return questions, nil
}
// getAvailableQuestionsForDailyWithPriority loads up to 50 candidate questions for
// the user's daily set, together with their priority score and response statistics.
//
// A question qualifies when it is assigned to the user, matches the requested
// language, level and type, has status 'active', and additionally:
//   - the user has no correct response to it within the last N days, where N is
//     the value returned by s.getDailyRepeatAvoidDays(),
//   - the user has not marked it as known with confidence 5 in the last 60 days.
//
// Results are ordered by priority score (default 100.0 when no score row exists),
// highest first, with ties broken randomly. The learning-preferences argument is
// accepted but ignored.
//
// Rows that fail to scan are logged and skipped. A failure of the query itself,
// or an error reported by the driver while iterating, is returned wrapped in
// ErrDatabaseQuery.
func (s *QuestionService) getAvailableQuestionsForDailyWithPriority(ctx context.Context, userID int, language, level string, qType models.QuestionType, _ *models.UserLearningPreferences) (result0 []*QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_available_questions_for_daily_with_priority", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(qType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Bound as $5 and turned into an interval inside the query.
	avoidDays := s.getDailyRepeatAvoidDays()
	query := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(qps.priority_score, 100.0) as priority_score,
COALESCE(uq_stats.times_answered, 0) as times_answered,
uq_stats.last_answered_at,
COALESCE(stats.correct_count, 0) as correct_count,
COALESCE(stats.incorrect_count, 0) as incorrect_count,
COALESCE(stats.total_responses, 0) as total_responses,
uqm.confidence_level
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
LEFT JOIN (
SELECT question_id,
COUNT(*) as times_answered,
MAX(created_at) as last_answered_at
FROM user_responses
WHERE user_id = $1
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(CASE WHEN is_correct = true THEN 1 END) as correct_count,
COUNT(CASE WHEN is_correct = false THEN 1 END) as incorrect_count,
COUNT(*) as total_responses
FROM user_responses
GROUP BY question_id
) stats ON q.id = stats.question_id
LEFT JOIN user_question_metadata uqm ON q.id = uqm.question_id AND uqm.user_id = $1
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
AND q.status = 'active'
AND NOT EXISTS (
SELECT 1
FROM user_responses ur
WHERE ur.user_id = $1
AND ur.question_id = q.id
AND ur.is_correct = TRUE
AND ur.created_at >= NOW() - ($5 || ' days')::interval
)
-- Exclude questions the user marked as known with confidence 5 within last 60 days
AND NOT EXISTS (
SELECT 1 FROM user_question_metadata uqm2
WHERE uqm2.user_id = $1
AND uqm2.question_id = q.id
AND uqm2.marked_as_known = TRUE
AND uqm2.confidence_level = 5
AND uqm2.marked_as_known_at >= NOW() - INTERVAL '60 days'
)
ORDER BY priority_score DESC, RANDOM()
LIMIT 50
`
	rows, err := s.db.QueryContext(ctx, query, userID, language, level, qType, avoidDays)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to query questions (daily): %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []*QuestionWithStats
	for rows.Next() {
		scanned, scanErr := s.scanQuestionWithPriorityAndStatsFromRows(rows)
		if scanErr != nil {
			s.logger.Error(ctx, "Error scanning question (daily)", scanErr, map[string]interface{}{})
			continue // Skip malformed rows
		}
		questions = append(questions, scanned)
	}
	// rows.Next returning false may mean the iteration was cut short (cancelled
	// context, dropped connection); surface that instead of a truncated list.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error iterating questions (daily): %w", err)
	}
	return questions, nil
}
// selectQuestionWithWeightedRandomness draws one question at random, biased
// toward questions that have been answered fewer times.
//
// Every candidate is weighted 1/(n+1), where n is its TimesAnswered when that
// value is non-negative and its TotalResponses otherwise. A uniform draw over
// the summed weights then picks the first candidate whose running weight
// reaches the drawn value. An empty candidate list yields an error wrapping
// ErrRecordNotFound.
func (s *QuestionService) selectQuestionWithWeightedRandomness(questions []*QuestionWithStats) (result0 *QuestionWithStats, err error) {
	if len(questions) == 0 {
		return nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "no questions available")
	}

	// Compute each candidate's weight once and accumulate the total in list order.
	weights := make([]float64, len(questions))
	sum := 0.0
	for i, candidate := range questions {
		// Per-user answer count wins over the global response count when present.
		answered := candidate.TotalResponses
		if candidate.TimesAnswered >= 0 {
			answered = candidate.TimesAnswered
		}
		weights[i] = 1.0 / (float64(answered) + 1.0)
		sum += weights[i]
	}

	// A non-positive total cannot drive a weighted draw; pick uniformly instead.
	if sum <= 0 {
		return questions[rand.Intn(len(questions))], nil
	}

	threshold := rand.Float64() * sum
	running := 0.0
	for i, candidate := range questions {
		running += weights[i]
		if running >= threshold {
			return candidate, nil
		}
	}

	// Floating-point rounding can leave the running total just short of the
	// threshold; fall back to the final candidate in that case.
	if last := len(questions) - 1; last >= 0 {
		return questions[last], nil
	}
	return nil, contextutils.WrapError(contextutils.ErrInternalError, "failed to select question with weighted randomness")
}
// selectQuestionWithFreshnessRatio splits the candidates into "fresh" and
// "review" pools and then delegates the final pick to
// selectQuestionWithWeightedRandomness.
//
// A question is fresh when TimesAnswered is zero; if TimesAnswered is negative,
// freshness falls back to TotalResponses being zero. When both pools are
// non-empty, the fresh pool is chosen with probability freshnessRatio;
// otherwise whichever pool is non-empty is used. An empty candidate list yields
// an error wrapping ErrRecordNotFound.
func (s *QuestionService) selectQuestionWithFreshnessRatio(questions []*QuestionWithStats, freshnessRatio float64) (result0 *QuestionWithStats, err error) {
	if len(questions) == 0 {
		return nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "no questions available")
	}

	// Partition the candidates.
	var unseen, seen []*QuestionWithStats
	for _, candidate := range questions {
		fresh := candidate.TotalResponses == 0
		if candidate.TimesAnswered >= 0 {
			fresh = candidate.TimesAnswered == 0
		}
		if fresh {
			unseen = append(unseen, candidate)
			continue
		}
		seen = append(seen, candidate)
	}

	// Choose which pool to draw from.
	var pool []*QuestionWithStats
	switch {
	case len(unseen) > 0 && len(seen) > 0:
		// Both pools populated: the ratio decides.
		pool = seen
		if rand.Float64() < freshnessRatio {
			pool = unseen
		}
	case len(unseen) > 0:
		pool = unseen
	case len(seen) > 0:
		pool = seen
	default:
		// No separation possible; draw from everything.
		pool = questions
	}
	if len(pool) == 0 {
		return nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "no questions available after freshness filtering")
	}

	picked, pickErr := s.selectQuestionWithWeightedRandomness(pool)
	if pickErr != nil {
		// No request context is available in this helper, so log against a background one.
		s.logger.Warn(context.Background(), "selectQuestionWithWeightedRandomness failed", map[string]interface{}{
			"total_questions":        len(questions),
			"fresh_questions":        len(unseen),
			"review_questions":       len(seen),
			"selected_category_size": len(pool),
			"freshness_ratio":        freshnessRatio,
			"error":                  pickErr.Error(),
		})
	}
	return picked, pickErr
}
// GetUserQuestionCount reports how many distinct active questions are assigned
// to the given user. A query or scan failure is returned wrapped in
// ErrDatabaseQuery.
func (s *QuestionService) GetUserQuestionCount(ctx context.Context, userID int) (result0 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_user_question_count", observability.AttributeUserID(userID))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const countSQL = `
SELECT COUNT(DISTINCT q.id)
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
WHERE uq.user_id = $1 AND q.status = 'active'
`
	var assigned int
	if err = s.db.QueryRowContext(ctx, countSQL, userID).Scan(&assigned); err != nil {
		return 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get user question count: %w", err)
	}
	return assigned, nil
}
// GetUserResponseCount reports how many rows the given user has in
// user_responses. A query or scan failure is returned wrapped in
// ErrDatabaseQuery.
func (s *QuestionService) GetUserResponseCount(ctx context.Context, userID int) (result0 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_user_response_count", observability.AttributeUserID(userID))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const countSQL = `SELECT COUNT(*) FROM user_responses WHERE user_id = $1`
	var responses int
	if err = s.db.QueryRowContext(ctx, countSQL, userID).Scan(&responses); err != nil {
		return 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get user response count: %w", err)
	}
	return responses, nil
}
// GetUsersForQuestion looks up who a question is assigned to.
//
// It returns at most five of the assigned users (ordered by username) together
// with the total number of user_questions rows for the question. The user slice
// is never nil: a question with no assignments yields an empty slice. Any query,
// scan or iteration failure is returned wrapped in ErrDatabaseQuery.
func (s *QuestionService) GetUsersForQuestion(ctx context.Context, questionID int) (result0 []*models.User, result1 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_users_for_question", observability.AttributeQuestionID(questionID))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	// Total assignment count comes from its own query, independent of the 5-row cap below.
	const countSQL = `SELECT COUNT(*) FROM user_questions WHERE question_id = $1`
	var assignedTotal int
	if err = s.db.QueryRowContext(ctx, countSQL, questionID).Scan(&assignedTotal); err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get user count for question: %w", err)
	}

	const usersSQL = `
SELECT u.id, u.username, u.email, u.timezone, u.password_hash, u.last_active,
u.preferred_language, u.current_level, u.ai_provider, u.ai_model,
u.ai_enabled, u.ai_api_key, u.created_at, u.updated_at
FROM users u
JOIN user_questions uq ON u.id = uq.user_id
WHERE uq.question_id = $1
ORDER BY u.username
LIMIT 5
`
	rows, err := s.db.QueryContext(ctx, usersSQL, questionID)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get users for question: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// Start from a non-nil empty slice so callers never receive nil.
	assigned := []*models.User{}
	for rows.Next() {
		u := new(models.User)
		// Destination order mirrors the SELECT column order above.
		scanErr := rows.Scan(
			&u.ID, &u.Username, &u.Email, &u.Timezone, &u.PasswordHash, &u.LastActive,
			&u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider, &u.AIModel,
			&u.AIEnabled, &u.AIAPIKey, &u.CreatedAt, &u.UpdatedAt,
		)
		if scanErr != nil {
			return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to scan user: %w", scanErr)
		}
		assigned = append(assigned, u)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error iterating users: %w", iterErr)
	}
	return assigned, assignedTotal, nil
}
// Helper: scan a *sql.Row into a QuestionWithStats (for single-row queries)
func (s *QuestionService) scanQuestionWithPriorityAndStatsFromRow(row *sql.Row) (result0 *QuestionWithStats, err error) {
questionWithStats := &QuestionWithStats{
Question: &models.Question{},
}
var contentJSON string
var priorityScore float64
var timesAnswered int
var lastAnsweredAt sql.NullTime
err = row.Scan(
&questionWithStats.ID,
&questionWithStats.Type,
&questionWithStats.Language,
&questionWithStats.Level,
&questionWithStats.DifficultyScore,
&contentJSON,
&questionWithStats.CorrectAnswer,
&questionWithStats.Explanation,
&questionWithStats.CreatedAt,
&questionWithStats.Status,
&questionWithStats.TopicCategory,
&questionWithStats.GrammarFocus,
&questionWithStats.VocabularyDomain,
&questionWithStats.Scenario,
&questionWithStats.StyleModifier,
&questionWithStats.DifficultyModifier,
&questionWithStats.TimeContext,
&priorityScore,
×Answered,
&lastAnsweredAt,
&questionWithStats.CorrectCount,
&questionWithStats.IncorrectCount,
&questionWithStats.TotalResponses,
)
if err != nil {
return nil, err
}
if err := questionWithStats.UnmarshalContentFromJSON(contentJSON); err != nil {
return nil, err
}
return questionWithStats, nil
}
// GetRandomGlobalQuestionForUser picks one random active question matching the
// given language, level and type that is not yet assigned to the user and that
// the user has not marked as known with confidence 5 in the last 60 days, then
// assigns it to the user and returns it.
//
// When no such question exists it returns (nil, nil). If the assignment step
// fails, the failure is logged as a warning and the question is still returned
// without an error.
func (s *QuestionService) GetRandomGlobalQuestionForUser(ctx context.Context, userID int, language, level string, qType models.QuestionType) (result0 *QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_random_global_question_for_user", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(qType))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	// The stats columns are constants here: a question the user has never been
	// assigned is selected with a fixed priority and zeroed counters.
	const pickSQL = `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
100.0 as priority_score, 0 as times_answered, NULL as last_answered_at, 0 as correct_count, 0 as incorrect_count, 0 as total_responses
FROM questions q
WHERE q.language = $1
AND q.level = $2
AND q.type = $3
AND q.status = 'active'
AND q.id NOT IN (
SELECT uq.question_id
FROM user_questions uq
WHERE uq.user_id = $4
)
-- Exclude questions the user marked as known with confidence 5 within last 60 days
AND NOT EXISTS (
SELECT 1 FROM user_question_metadata uqm2
WHERE uqm2.user_id = $4
AND uqm2.question_id = q.id
AND uqm2.marked_as_known = TRUE
AND uqm2.confidence_level = 5
AND uqm2.marked_as_known_at >= NOW() - INTERVAL '60 days'
)
ORDER BY RANDOM()
LIMIT 1
`
	picked, err := s.scanQuestionWithPriorityAndStatsFromRow(
		s.db.QueryRowContext(ctx, pickSQL, language, level, qType, userID),
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// No global questions available
		return nil, nil
	case err != nil:
		return nil, err
	}

	// Assignment is best-effort: the caller still gets the question on failure.
	if assignErr := s.AssignQuestionToUser(ctx, picked.ID, userID); assignErr != nil {
		s.logger.Warn(ctx, "Failed to assign global question to user", map[string]interface{}{"question_id": picked.ID, "user_id": userID, "error": assignErr.Error()})
	}
	return picked, nil
}
// GetAllQuestionsPaginated returns one page of questions (newest first) with
// response and assignment statistics, plus the total number of questions
// matching the filters.
//
// Filters are applied only when non-empty / non-nil:
//   - search: case-insensitive substring match on content (as text) or explanation
//   - typeFilter, statusFilter, languageFilter, levelFilter: exact match
//   - userID: only questions assigned to that user
//
// page is 1-based; values below 1 are treated as 1. A pageSize below 1 falls
// back to 10. All filter values are passed as bound parameters. Query failures
// and iteration errors are returned wrapped in ErrDatabaseQuery; row scan
// errors are returned as-is.
func (s *QuestionService) GetAllQuestionsPaginated(ctx context.Context, page, pageSize int, search, typeFilter, statusFilter, languageFilter, levelFilter string, userID *int) (result0 []*QuestionWithStats, result1 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_all_questions_paginated")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Guard against a negative OFFSET or non-positive LIMIT reaching the database.
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 10
	}

	// Build the base query
	baseQuery := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(ur_stats.correct_count, 0) as correct_count,
COALESCE(ur_stats.incorrect_count, 0) as incorrect_count,
COALESCE(ur_stats.total_responses, 0) as total_responses,
COALESCE(uq_stats.user_count, 0) as user_count
FROM questions q
LEFT JOIN (
SELECT
question_id,
COUNT(CASE WHEN is_correct = true THEN 1 END) as correct_count,
COUNT(CASE WHEN is_correct = false THEN 1 END) as incorrect_count,
COUNT(*) as total_responses
FROM user_responses
GROUP BY question_id
) ur_stats ON q.id = ur_stats.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(*) as user_count
FROM user_questions
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
WHERE 1=1
`
	// Build the count query
	countQuery := `
SELECT COUNT(*)
FROM questions q
WHERE 1=1
`
	var args []interface{}
	argIndex := 1

	// placeholder renders the next free positional parameter, e.g. "$3".
	placeholder := func() string { return "$" + strconv.Itoa(argIndex) }
	// addFilter appends the same condition to both queries, binds its value and
	// advances the parameter index so the two queries stay in lock-step.
	addFilter := func(condition string, value interface{}) {
		baseQuery += condition
		countQuery += condition
		args = append(args, value)
		argIndex++
	}

	if search != "" {
		p := placeholder()
		addFilter(` AND (q.content::text ILIKE `+p+` OR q.explanation ILIKE `+p+`)`, "%"+search+"%")
	}
	if typeFilter != "" {
		addFilter(` AND q.type = `+placeholder(), typeFilter)
	}
	if statusFilter != "" {
		addFilter(` AND q.status = `+placeholder(), statusFilter)
	}
	if languageFilter != "" {
		addFilter(` AND q.language = `+placeholder(), languageFilter)
	}
	if levelFilter != "" {
		addFilter(` AND q.level = `+placeholder(), levelFilter)
	}
	if userID != nil {
		addFilter(` AND q.id IN (SELECT question_id FROM user_questions WHERE user_id = `+placeholder()+`)`, *userID)
	}

	// Get total count (uses the filter arguments only, before pagination is appended)
	var total int
	err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get total count: %w", err)
	}

	// Add pagination
	offset := (page - 1) * pageSize
	baseQuery += ` ORDER BY q.created_at DESC LIMIT $` + strconv.Itoa(argIndex) + ` OFFSET $` + strconv.Itoa(argIndex+1)
	args = append(args, pageSize, offset)

	// Execute the main query
	rows, err := s.db.QueryContext(ctx, baseQuery, args...)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get questions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []*QuestionWithStats
	for rows.Next() {
		question, scanErr := s.scanQuestionWithStatsAndAllFieldsFromRows(rows)
		if scanErr != nil {
			return nil, 0, scanErr
		}
		questions = append(questions, question)
	}
	// Distinguish "no more rows" from an iteration that was cut short.
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error iterating questions: %w", err)
	}
	return questions, total, nil
}
// GetReportedQuestionsPaginated returns one page of questions whose status is
// 'reported' (newest first), with response statistics and the aggregated
// reporter usernames and report reasons, plus the total number of reported
// questions matching the filters.
//
// Filters are applied only when non-empty:
//   - search: case-insensitive substring match on content (as text) or explanation
//   - typeFilter, languageFilter, levelFilter: exact match
//
// page is 1-based; values below 1 are treated as 1. A pageSize below 1 falls
// back to 10. All filter values are passed as bound parameters. Query failures
// and iteration errors are returned wrapped in ErrDatabaseQuery; row scan
// errors are returned as-is.
func (s *QuestionService) GetReportedQuestionsPaginated(ctx context.Context, page, pageSize int, search, typeFilter, languageFilter, levelFilter string) (result0 []*QuestionWithStats, result1 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_reported_questions_paginated")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Validate pagination parameters
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 10
	}

	// Build WHERE clause with filters using parameterized queries.
	// argCount tracks the highest positional parameter used so far.
	whereConditions := []string{"q.status = 'reported'"}
	args := []interface{}{}
	argCount := 0

	// Add search filter
	if search != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("(q.content::text ILIKE $%d OR q.explanation ILIKE $%d)", argCount, argCount))
		args = append(args, "%"+search+"%")
	}
	// Add type filter
	if typeFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.type = $%d", argCount))
		args = append(args, typeFilter)
	}
	// Add language filter
	if languageFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.language = $%d", argCount))
		args = append(args, languageFilter)
	}
	// Add level filter
	if levelFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.level = $%d", argCount))
		args = append(args, levelFilter)
	}

	// Join all conditions
	whereClause := "WHERE " + strings.Join(whereConditions, " AND ")

	// Build the count query (filter arguments only; pagination is appended afterwards)
	countQuery := fmt.Sprintf("SELECT COUNT(DISTINCT q.id) FROM questions q %s", whereClause)
	var total int
	err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get total count: %w", err)
	}

	// Calculate offset
	offset := (page - 1) * pageSize

	// Build main query with pagination
	query := fmt.Sprintf(`
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(ur_stats.correct_count, 0) as correct_count,
COALESCE(ur_stats.incorrect_count, 0) as incorrect_count,
COALESCE(ur_stats.total_responses, 0) as total_responses,
STRING_AGG(DISTINCT u.username, ', ') as reporters,
STRING_AGG(DISTINCT qr.report_reason, ' | ') as report_reasons
FROM questions q
LEFT JOIN (
SELECT
question_id,
COUNT(CASE WHEN is_correct = true THEN 1 END) as correct_count,
COUNT(CASE WHEN is_correct = false THEN 1 END) as incorrect_count,
COUNT(*) as total_responses
FROM user_responses
GROUP BY question_id
) ur_stats ON q.id = ur_stats.question_id
LEFT JOIN question_reports qr ON q.id = qr.question_id
LEFT JOIN users u ON qr.reported_by_user_id = u.id
%s
GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
ur_stats.correct_count, ur_stats.incorrect_count, ur_stats.total_responses
ORDER BY q.created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argCount+1, argCount+2)

	// Add pagination parameters
	args = append(args, pageSize, offset)

	// Execute the main query
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get reported questions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []*QuestionWithStats
	for rows.Next() {
		question, scanErr := s.scanQuestionWithStatsAndReportersFromRows(rows)
		if scanErr != nil {
			return nil, 0, scanErr
		}
		questions = append(questions, question)
	}
	// Distinguish "no more rows" from an iteration that was cut short.
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error iterating reported questions: %w", err)
	}
	return questions, total, nil
}
// GetReportedQuestionsStats summarizes questions whose status is 'reported'.
//
// The returned map holds:
//   - "total_reported":       int, number of reported questions
//   - "reported_by_type":     map[string]int, count per question type
//   - "reported_by_level":    map[string]int, count per level
//   - "reported_by_language": map[string]int, count per language
//
// Any database failure aborts the whole call and returns a nil map.
func (s *QuestionService) GetReportedQuestionsStats(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_reported_questions_stats")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	stats := make(map[string]interface{})

	// Get total reported questions
	var totalReported int
	err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM questions WHERE status = 'reported'`).Scan(&totalReported)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get total reported questions: %w", err)
	}
	stats["total_reported"] = totalReported

	// Each breakdown is the same "GROUP BY <column>" query over a different column.
	breakdowns := []struct {
		statsKey string // key in the returned map
		label    string // dimension name used in error messages
		query    string
	}{
		{"reported_by_type", "type", `
SELECT type, COUNT(*) as count
FROM questions
WHERE status = 'reported'
GROUP BY type
ORDER BY count DESC
`},
		{"reported_by_level", "level", `
SELECT level, COUNT(*) as count
FROM questions
WHERE status = 'reported'
GROUP BY level
ORDER BY count DESC
`},
		{"reported_by_language", "language", `
SELECT language, COUNT(*) as count
FROM questions
WHERE status = 'reported'
GROUP BY language
ORDER BY count DESC
`},
	}
	for _, b := range breakdowns {
		counts, countErr := s.reportedQuestionCounts(ctx, b.label, b.query)
		if countErr != nil {
			return nil, countErr
		}
		stats[b.statsKey] = counts
	}
	return stats, nil
}

// reportedQuestionCounts runs a two-column (key, count) query and collects the
// result into a map. label names the grouping dimension in error messages. The
// result set is closed before returning, so each breakdown owns its own rows.
func (s *QuestionService) reportedQuestionCounts(ctx context.Context, label, query string) (map[string]int, error) {
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get reported questions by %s: %w", label, err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	counts := make(map[string]int)
	for rows.Next() {
		var key string
		var count int
		if err := rows.Scan(&key, &count); err != nil {
			return nil, err
		}
		counts[key] = count
	}
	// Distinguish "no more rows" from an iteration that was cut short.
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error iterating reported questions by %s: %w", label, err)
	}
	return counts, nil
}
// AssignUsersToQuestion assigns every user in userIDs to the given question
// inside a single transaction.
//
// Pairs that already exist are left untouched (ON CONFLICT DO NOTHING). If any
// insert or the commit fails, the transaction is rolled back and no assignment
// from this call is kept.
func (s *QuestionService) AssignUsersToQuestion(ctx context.Context, questionID int, userIDs []int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "assign_users_to_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Start a transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	defer func() {
		if err == nil {
			return
		}
		// After a failed Commit or a cancelled context the transaction is already
		// finished and Rollback reports sql.ErrTxDone; that is not worth a warning.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Prepare the insert statement once and reuse it for every user
	stmt, err := tx.PrepareContext(ctx, `
INSERT INTO user_questions (user_id, question_id, created_at)
VALUES ($1, $2, NOW())
ON CONFLICT (user_id, question_id) DO NOTHING
`)
	if err != nil {
		return contextutils.WrapError(err, "failed to prepare insert statement")
	}
	defer func() {
		if closeErr := stmt.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close statement", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// Insert each user-question mapping
	for _, userID := range userIDs {
		if _, err = stmt.ExecContext(ctx, userID, questionID); err != nil {
			return contextutils.WrapErrorf(err, "failed to assign user %d to question %d", userID, questionID)
		}
	}

	// Commit the transaction
	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	return nil
}
// UnassignUsersFromQuestion removes the assignment between the given question
// and every user in userIDs inside a single transaction.
//
// Deleting a pair that does not exist is not an error. If any delete or the
// commit fails, the transaction is rolled back and no removal from this call is
// kept.
func (s *QuestionService) UnassignUsersFromQuestion(ctx context.Context, questionID int, userIDs []int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "unassign_users_from_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Start a transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	defer func() {
		if err == nil {
			return
		}
		// After a failed Commit or a cancelled context the transaction is already
		// finished and Rollback reports sql.ErrTxDone; that is not worth a warning.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Prepare the delete statement once and reuse it for every user
	stmt, err := tx.PrepareContext(ctx, `
DELETE FROM user_questions
WHERE user_id = $1 AND question_id = $2
`)
	if err != nil {
		return contextutils.WrapError(err, "failed to prepare delete statement")
	}
	defer func() {
		if closeErr := stmt.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close statement", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// Delete each user-question mapping
	for _, userID := range userIDs {
		if _, err = stmt.ExecContext(ctx, userID, questionID); err != nil {
			return contextutils.WrapErrorf(err, "failed to unassign user %d from question %d", userID, questionID)
		}
	}

	// Commit the transaction
	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	return nil
}
// DB returns the underlying *sql.DB instance held by the service.
// The stored handle is returned as-is, without copying or wrapping.
func (s *QuestionService) DB() *sql.DB {
	return s.db
}
package services
import (
"context"
"database/sql"
"fmt"
"strings"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/serviceinterfaces"
contextutils "quizapp/internal/utils"
"github.com/lib/pq"
"go.opentelemetry.io/otel/attribute"
)
// SnippetsServiceInterface defines the interface for snippets services.
// It is a type alias for serviceinterfaces.SnippetsService, so both names
// denote the same type.
type SnippetsServiceInterface = serviceinterfaces.SnippetsService
// SnippetsService handles snippets related business logic
type SnippetsService struct {
	// db is the database handle; the methods in this file return an internal
	// error instead of querying when it is nil.
	db *sql.DB
	// cfg is the application configuration supplied at construction time.
	cfg *config.Config
	// logger receives the warning and info events emitted by the service.
	logger *observability.Logger
}
// NewSnippetsService builds a SnippetsService around the given database
// handle, configuration and logger. The arguments are stored without
// validation.
func NewSnippetsService(db *sql.DB, cfg *config.Config, logger *observability.Logger) *SnippetsService {
	svc := new(SnippetsService)
	svc.db = db
	svc.cfg = cfg
	svc.logger = logger
	return svc
}
// getDefaultDifficultyLevel returns the difficulty level used when a snippet
// has no question or section context to derive one from.
func (s *SnippetsService) getDefaultDifficultyLevel() string {
	// "Unknown" is the placeholder level; users can always update it through
	// the UI if needed.
	const fallbackLevel = "Unknown"
	return fallbackLevel
}
// getQuestionLevel looks up the level column of the question with the given
// ID. It returns a wrapped ErrRecordNotFound when no such question exists and
// a wrapped ErrInternalError when the service has no database handle.
func (s *SnippetsService) getQuestionLevel(ctx context.Context, questionID int64) (result string, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_question_level",
		observability.AttributeQuestionID(int(questionID)),
	)
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return "", contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}

	row := s.db.QueryRowContext(ctx, `SELECT level FROM questions WHERE id = $1`, questionID)
	err = row.Scan(&result)
	switch {
	case err == nil:
		return result, nil
	case err == sql.ErrNoRows:
		return "", contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with id %d not found", questionID)
	default:
		return "", contextutils.WrapErrorf(err, "failed to get question level for question %d", questionID)
	}
}
// CreateSnippet creates a new vocabulary snippet for userID.
//
// The snippet is rejected with a wrapped ErrRecordExists when the user already
// has a snippet with the same original text and language pair. The difficulty
// level is taken from the referenced question when req.QuestionId is set,
// otherwise from the referenced story section when req.SectionId is set, and
// falls back to the default level when neither is set or the lookup fails
// (lookup failures are logged, not returned).
//
// On success the returned snippet carries the database-generated ID and
// timestamps plus the fields copied from req.
func (s *SnippetsService) CreateSnippet(ctx context.Context, userID int64, req api.CreateSnippetRequest) (result *models.Snippet, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "create_snippet")
	defer observability.FinishSpan(span, &err)
	// Check if database connection is valid
	if s.db == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(observability.AttributeUserID(int(userID)))

	// Check if snippet already exists for this user and text combination
	exists, err := s.snippetExists(ctx, userID, req.OriginalText, req.SourceLanguage, req.TargetLanguage)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to check snippet existence")
	}
	if exists {
		return nil, contextutils.WrapError(contextutils.ErrRecordExists, "snippet already exists for this user and text combination")
	}

	// Determine difficulty level - use question's level if question_id is provided, or section's level if section_id is provided.
	// Lookup errors get their own variable so they never leak into the named
	// result err: a failed lookup is a soft failure that only changes the level.
	var difficultyLevel string
	var levelSource string
	if req.QuestionId != nil {
		// Get the question's difficulty level
		questionLevel, lookupErr := s.getQuestionLevel(ctx, *req.QuestionId)
		if lookupErr != nil {
			// If we can't get the question level, use default
			s.logger.Warn(ctx, "Failed to get question level, using default",
				map[string]any{"question_id": *req.QuestionId, "error": lookupErr.Error()})
			difficultyLevel = s.getDefaultDifficultyLevel()
			levelSource = "default_fallback"
		} else {
			difficultyLevel = questionLevel
			levelSource = "question"
		}
	} else if req.SectionId != nil {
		// Get the story section's language level
		sectionLevel, lookupErr := s.getSectionLevel(ctx, *req.SectionId)
		if lookupErr != nil {
			// If we can't get the section level, use default
			s.logger.Warn(ctx, "Failed to get section level, using default",
				map[string]any{"section_id": *req.SectionId, "error": lookupErr.Error()})
			difficultyLevel = s.getDefaultDifficultyLevel()
			levelSource = "default_fallback"
		} else {
			difficultyLevel = sectionLevel
			levelSource = "section"
		}
	} else {
		// No question or section context, use default
		difficultyLevel = s.getDefaultDifficultyLevel()
		levelSource = "default"
	}
	span.SetAttributes(observability.AttributeLevel(difficultyLevel))

	// Insert new snippet
	query := `
INSERT INTO snippets (user_id, original_text, translated_text, source_language, target_language, question_id, section_id, story_id, context, difficulty_level)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING id, created_at, updated_at`
	result = &models.Snippet{}
	err = s.db.QueryRowContext(ctx, query,
		userID,
		req.OriginalText,
		req.TranslatedText,
		req.SourceLanguage,
		req.TargetLanguage,
		req.QuestionId,
		req.SectionId,
		req.StoryId,
		req.Context,
		difficultyLevel,
	).Scan(&result.ID, &result.CreatedAt, &result.UpdatedAt)
	if err != nil {
		// The existence check above and this INSERT are not atomic, so a
		// concurrent request can insert the same snippet in between. Report
		// the resulting unique violation (SQLSTATE 23505) as a conflict, the
		// same way UpdateSnippet does, instead of as a generic failure.
		if pqErr, ok := err.(*pq.Error); ok && pqErr.Code == "23505" {
			return nil, contextutils.WrapError(contextutils.ErrRecordExists, "snippet already exists for this user and text combination")
		}
		return nil, contextutils.WrapErrorf(err, "failed to create snippet")
	}

	// Set the remaining fields
	result.UserID = userID
	result.OriginalText = req.OriginalText
	result.TranslatedText = req.TranslatedText
	result.SourceLanguage = req.SourceLanguage
	result.TargetLanguage = req.TargetLanguage
	result.QuestionID = req.QuestionId
	result.SectionID = req.SectionId
	result.StoryID = req.StoryId
	result.Context = req.Context
	result.DifficultyLevel = &difficultyLevel

	s.logger.Info(ctx, "Created new snippet",
		map[string]any{
			"snippet_id":       result.ID,
			"user_id":          userID,
			"original_text":    req.OriginalText,
			"source_language":  req.SourceLanguage,
			"difficulty_level": difficultyLevel,
			"level_source":     levelSource,
			"question_id":      req.QuestionId,
		})
	return result, nil
}
// getSectionLevel looks up the language_level column of the story section
// with the given ID. It returns a wrapped ErrRecordNotFound when no such
// section exists and a wrapped ErrInternalError when the service has no
// database handle.
func (s *SnippetsService) getSectionLevel(ctx context.Context, sectionID int64) (result string, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "get_section_level")
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return "", contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}

	row := s.db.QueryRowContext(ctx, `SELECT language_level FROM story_sections WHERE id = $1`, sectionID)
	err = row.Scan(&result)
	switch {
	case err == nil:
		return result, nil
	case err == sql.ErrNoRows:
		return "", contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "story section with id %d not found", sectionID)
	default:
		return "", contextutils.WrapErrorf(err, "failed to get section level for section %d", sectionID)
	}
}
// GetSnippets retrieves snippets for a user with optional filtering.
//
// Supported filters (each applied only when set and non-empty/positive):
// free-text search over original and translated text (params.Q, ILIKE
// substring match), source language, target language, story ID and difficulty
// level. Results are ordered newest first. A positive params.Limit is applied
// with a maximum of 100; a positive params.Offset is applied as-is.
//
// The returned list also carries the total number of rows matching the
// filters (ignoring limit/offset) and echoes the requested limit/offset.
//
// NOTE(review): the echoed Limit is the caller's raw value (or 50 when unset),
// not the effective one: the query caps it at 100 and applies no LIMIT at all
// when params.Limit is nil. Left unchanged here to avoid altering the
// response contract.
func (s *SnippetsService) GetSnippets(ctx context.Context, userID int64, params api.GetV1SnippetsParams) (result *api.SnippetList, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "get_snippets")
	defer observability.FinishSpan(span, &err)
	// Check if database connection is valid
	if s.db == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(observability.AttributeUserID(int(userID)))

	// Build the optional filter clause once; the data query and the count
	// query below share both the clause and its arguments, so they cannot
	// drift apart. $1 is always the user ID.
	filterArgs := []any{userID}
	filters := ""
	// bind appends a filter value and returns its 1-based placeholder number.
	bind := func(value any) int {
		filterArgs = append(filterArgs, value)
		return len(filterArgs)
	}
	// Add search filter if provided
	if params.Q != nil && *params.Q != "" {
		n := bind("%" + *params.Q + "%")
		filters += fmt.Sprintf(" AND (original_text ILIKE $%d OR translated_text ILIKE $%d)", n, n)
	}
	// Add source language filter if provided
	if params.SourceLang != nil && *params.SourceLang != "" {
		filters += fmt.Sprintf(" AND source_language = $%d", bind(*params.SourceLang))
	}
	// Add target language filter if provided
	if params.TargetLang != nil && *params.TargetLang != "" {
		filters += fmt.Sprintf(" AND target_language = $%d", bind(*params.TargetLang))
	}
	// Add story_id filter if provided
	if params.StoryId != nil && *params.StoryId > 0 {
		filters += fmt.Sprintf(" AND story_id = $%d", bind(*params.StoryId))
	}
	// Add difficulty level filter if provided
	if params.Level != nil && *params.Level != "" {
		filters += fmt.Sprintf(" AND difficulty_level = $%d", bind(string(*params.Level)))
	}

	query := `
SELECT id, user_id, original_text, translated_text, source_language, target_language,
question_id, section_id, story_id, context, difficulty_level, created_at, updated_at
FROM snippets
WHERE user_id = $1` + filters + " ORDER BY created_at DESC"

	// The data query gets its own argument slice so that the pagination
	// values are never visible to the count query.
	args := make([]any, len(filterArgs), len(filterArgs)+2)
	copy(args, filterArgs)
	if params.Limit != nil && *params.Limit > 0 {
		limit := *params.Limit
		if limit > 100 { // Max limit
			limit = 100
		}
		args = append(args, limit)
		query += fmt.Sprintf(" LIMIT $%d", len(args))
	}
	if params.Offset != nil && *params.Offset > 0 {
		args = append(args, *params.Offset)
		query += fmt.Sprintf(" OFFSET $%d", len(args))
	}

	// Execute query
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to query snippets")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]any{"error": closeErr.Error()})
		}
	}()

	// Non-nil empty slice so an empty result is still a list, not nil.
	snippets := []api.Snippet{}
	for rows.Next() {
		// Declared per iteration: the response keeps pointers into it.
		var snippet models.Snippet
		scanErr := rows.Scan(
			&snippet.ID,
			&snippet.UserID,
			&snippet.OriginalText,
			&snippet.TranslatedText,
			&snippet.SourceLanguage,
			&snippet.TargetLanguage,
			&snippet.QuestionID,
			&snippet.SectionID,
			&snippet.StoryID,
			&snippet.Context,
			&snippet.DifficultyLevel,
			&snippet.CreatedAt,
			&snippet.UpdatedAt,
		)
		if scanErr != nil {
			return nil, contextutils.WrapErrorf(scanErr, "failed to scan snippet")
		}
		snippets = append(snippets, api.Snippet{
			Id:              &snippet.ID,
			UserId:          &snippet.UserID,
			OriginalText:    &snippet.OriginalText,
			TranslatedText:  &snippet.TranslatedText,
			SourceLanguage:  &snippet.SourceLanguage,
			TargetLanguage:  &snippet.TargetLanguage,
			QuestionId:      snippet.QuestionID,
			SectionId:       snippet.SectionID,
			StoryId:         snippet.StoryID,
			Context:         snippet.Context,
			DifficultyLevel: snippet.DifficultyLevel,
			CreatedAt:       &snippet.CreatedAt,
			UpdatedAt:       &snippet.UpdatedAt,
		})
	}
	// rows.Next returning false can mean "done" or "failed"; without this
	// check a mid-iteration error would be returned as a truncated list.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(err, "error iterating over snippet rows")
	}

	// Get total count for pagination info, using the same filters.
	totalQuery := "SELECT COUNT(*) FROM snippets WHERE user_id = $1" + filters
	var total int
	err = s.db.QueryRowContext(ctx, totalQuery, filterArgs...).Scan(&total)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get total count")
	}

	// Build response
	limit := 50 // default
	offset := 0 // default
	if params.Limit != nil {
		limit = *params.Limit
	}
	if params.Offset != nil {
		offset = *params.Offset
	}
	result = &api.SnippetList{
		Snippets: &snippets,
		Total:    &total,
		Limit:    &limit,
		Offset:   &offset,
		Query:    params.Q,
	}
	return result, nil
}
// GetSnippetsByQuestion returns the user's snippets attached to the given
// question, newest first. The result is an empty (non-nil) slice when nothing
// matches.
//
// The query does not select section_id or story_id, so SectionId and StoryId
// are left unset on the returned snippets.
func (s *SnippetsService) GetSnippetsByQuestion(ctx context.Context, userID, questionID int64) (result []api.Snippet, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "get_snippets_by_question")
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(
		observability.AttributeUserID(int(userID)),
		observability.AttributeQuestionID(int(questionID)),
	)

	// Fetch the snippets for this user and question.
	// Uses the existing idx_snippets_question_id index for performance
	rows, err := s.db.QueryContext(ctx, `
SELECT id, user_id, original_text, translated_text, source_language, target_language,
question_id, context, difficulty_level, created_at, updated_at
FROM snippets
WHERE user_id = $1 AND question_id = $2
ORDER BY created_at DESC`, userID, questionID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get snippets by question")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]any{"error": closeErr.Error()})
		}
	}()

	found := make([]api.Snippet, 0)
	for rows.Next() {
		// A fresh record per row: the API value keeps pointers into it.
		var rec models.Snippet
		if scanErr := rows.Scan(
			&rec.ID,
			&rec.UserID,
			&rec.OriginalText,
			&rec.TranslatedText,
			&rec.SourceLanguage,
			&rec.TargetLanguage,
			&rec.QuestionID,
			&rec.Context,
			&rec.DifficultyLevel,
			&rec.CreatedAt,
			&rec.UpdatedAt,
		); scanErr != nil {
			return nil, contextutils.WrapErrorf(scanErr, "failed to scan snippet")
		}
		item := api.Snippet{
			Id:              &rec.ID,
			UserId:          &rec.UserID,
			OriginalText:    &rec.OriginalText,
			TranslatedText:  &rec.TranslatedText,
			SourceLanguage:  &rec.SourceLanguage,
			TargetLanguage:  &rec.TargetLanguage,
			QuestionId:      rec.QuestionID,
			Context:         rec.Context,
			DifficultyLevel: rec.DifficultyLevel,
			CreatedAt:       &rec.CreatedAt,
			UpdatedAt:       &rec.UpdatedAt,
		}
		found = append(found, item)
	}
	// Surface any error that ended the iteration early.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(err, "error iterating over snippet rows")
	}
	return found, nil
}
// GetSnippetsBySection returns the user's snippets attached to the given
// story section, newest first. The result is an empty (non-nil) slice when
// nothing matches.
func (s *SnippetsService) GetSnippetsBySection(ctx context.Context, userID, sectionID int64) (result []api.Snippet, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "get_snippets_by_section")
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(
		observability.AttributeUserID(int(userID)),
		attribute.Int64("section.id", sectionID),
	)

	// Fetch the snippets for this user and section.
	// Uses the new idx_snippets_section_id index for performance
	rows, err := s.db.QueryContext(ctx, `
SELECT id, user_id, original_text, translated_text, source_language, target_language,
question_id, section_id, story_id, context, difficulty_level, created_at, updated_at
FROM snippets
WHERE user_id = $1 AND section_id = $2
ORDER BY created_at DESC`, userID, sectionID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get snippets by section")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]any{"error": closeErr.Error()})
		}
	}()

	found := make([]api.Snippet, 0)
	for rows.Next() {
		// A fresh record per row: the API value keeps pointers into it.
		var rec models.Snippet
		if scanErr := rows.Scan(
			&rec.ID,
			&rec.UserID,
			&rec.OriginalText,
			&rec.TranslatedText,
			&rec.SourceLanguage,
			&rec.TargetLanguage,
			&rec.QuestionID,
			&rec.SectionID,
			&rec.StoryID,
			&rec.Context,
			&rec.DifficultyLevel,
			&rec.CreatedAt,
			&rec.UpdatedAt,
		); scanErr != nil {
			return nil, contextutils.WrapErrorf(scanErr, "failed to scan snippet")
		}
		item := api.Snippet{
			Id:              &rec.ID,
			UserId:          &rec.UserID,
			OriginalText:    &rec.OriginalText,
			TranslatedText:  &rec.TranslatedText,
			SourceLanguage:  &rec.SourceLanguage,
			TargetLanguage:  &rec.TargetLanguage,
			QuestionId:      rec.QuestionID,
			SectionId:       rec.SectionID,
			StoryId:         rec.StoryID,
			Context:         rec.Context,
			DifficultyLevel: rec.DifficultyLevel,
			CreatedAt:       &rec.CreatedAt,
			UpdatedAt:       &rec.UpdatedAt,
		}
		found = append(found, item)
	}
	// Surface any error that ended the iteration early.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(err, "error iterating over snippet rows")
	}
	return found, nil
}
// GetSnippetsByStory returns the user's snippets attached to the given story,
// newest first. The result is an empty (non-nil) slice when nothing matches.
func (s *SnippetsService) GetSnippetsByStory(ctx context.Context, userID, storyID int64) (result []api.Snippet, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "get_snippets_by_story")
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(
		observability.AttributeUserID(int(userID)),
		attribute.Int64("story.id", storyID),
	)

	// Fetch the snippets for this user and story.
	// Uses the new idx_snippets_story_id index for performance
	rows, err := s.db.QueryContext(ctx, `
SELECT id, user_id, original_text, translated_text, source_language, target_language,
question_id, section_id, story_id, context, difficulty_level, created_at, updated_at
FROM snippets
WHERE user_id = $1 AND story_id = $2
ORDER BY created_at DESC`, userID, storyID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get snippets by story")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]any{"error": closeErr.Error()})
		}
	}()

	found := make([]api.Snippet, 0)
	for rows.Next() {
		// A fresh record per row: the API value keeps pointers into it.
		var rec models.Snippet
		if scanErr := rows.Scan(
			&rec.ID,
			&rec.UserID,
			&rec.OriginalText,
			&rec.TranslatedText,
			&rec.SourceLanguage,
			&rec.TargetLanguage,
			&rec.QuestionID,
			&rec.SectionID,
			&rec.StoryID,
			&rec.Context,
			&rec.DifficultyLevel,
			&rec.CreatedAt,
			&rec.UpdatedAt,
		); scanErr != nil {
			return nil, contextutils.WrapErrorf(scanErr, "failed to scan snippet")
		}
		item := api.Snippet{
			Id:              &rec.ID,
			UserId:          &rec.UserID,
			OriginalText:    &rec.OriginalText,
			TranslatedText:  &rec.TranslatedText,
			SourceLanguage:  &rec.SourceLanguage,
			TargetLanguage:  &rec.TargetLanguage,
			QuestionId:      rec.QuestionID,
			SectionId:       rec.SectionID,
			StoryId:         rec.StoryID,
			Context:         rec.Context,
			DifficultyLevel: rec.DifficultyLevel,
			CreatedAt:       &rec.CreatedAt,
			UpdatedAt:       &rec.UpdatedAt,
		}
		found = append(found, item)
	}
	// Surface any error that ended the iteration early.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(err, "error iterating over snippet rows")
	}
	return found, nil
}
// SearchSnippets searches across all snippets for a user.
//
// query is trimmed and lower-cased, then matched as a substring against the
// lower-cased original_text and translated_text. An empty query (after
// trimming) yields a wrapped ErrInvalidInput. When sourceLang is set and
// non-empty, results are additionally restricted to that source language.
//
// It returns one page of matches (newest first, bounded by limit/offset)
// together with the total number of matching snippets.
//
// NOTE(review): '%' and '_' inside query are passed to LIKE unescaped, so they
// act as wildcards.
func (s *SnippetsService) SearchSnippets(ctx context.Context, userID int64, query string, limit, offset int, sourceLang *string) (result []api.Snippet, totalCount int, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "search_snippets")
	defer observability.FinishSpan(span, &err)
	// Check if database connection is valid
	if s.db == nil {
		return nil, 0, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(observability.AttributeUserID(int(userID)))

	// Clean and prepare the search query
	searchQuery := strings.TrimSpace(query)
	if searchQuery == "" {
		return nil, 0, contextutils.WrapError(contextutils.ErrInvalidInput, "search query cannot be empty")
	}
	// Search in both original_text and translated_text
	searchTerm := fmt.Sprintf("%%%s%%", strings.ToLower(searchQuery))

	// Get total count of matching snippets
	totalQuery := `
SELECT COUNT(*)
FROM snippets
WHERE user_id = $1 AND (LOWER(original_text) LIKE $2 OR LOWER(translated_text) LIKE $3)`
	var total int
	// Add optional source language filter
	totalArgs := []any{userID, searchTerm, searchTerm}
	if sourceLang != nil && *sourceLang != "" {
		totalQuery += " AND source_language = $4"
		totalArgs = append(totalArgs, *sourceLang)
	}
	err = s.db.QueryRowContext(ctx, totalQuery, totalArgs...).Scan(&total)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(err, "failed to get total count for search")
	}

	// Get matching snippets
	queryStr := `
SELECT id, user_id, original_text, translated_text, source_language, target_language,
question_id, section_id, story_id, context, difficulty_level, created_at, updated_at
FROM snippets
WHERE user_id = $1 AND (LOWER(original_text) LIKE $2 OR LOWER(translated_text) LIKE $3)`
	args := []any{userID, searchTerm, searchTerm}
	if sourceLang != nil && *sourceLang != "" {
		queryStr += " AND source_language = $4"
		args = append(args, *sourceLang)
		queryStr += " ORDER BY created_at DESC LIMIT $5 OFFSET $6"
		args = append(args, limit, offset)
	} else {
		queryStr += " ORDER BY created_at DESC LIMIT $4 OFFSET $5"
		args = append(args, limit, offset)
	}
	rows, err := s.db.QueryContext(ctx, queryStr, args...)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(err, "failed to search snippets")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]any{"error": closeErr.Error()})
		}
	}()

	// Non-nil empty slice so an empty result is still a list, not nil.
	snippets := []api.Snippet{}
	for rows.Next() {
		// Declared per iteration: the response keeps pointers into it.
		var snippet models.Snippet
		scanErr := rows.Scan(
			&snippet.ID,
			&snippet.UserID,
			&snippet.OriginalText,
			&snippet.TranslatedText,
			&snippet.SourceLanguage,
			&snippet.TargetLanguage,
			&snippet.QuestionID,
			&snippet.SectionID,
			&snippet.StoryID,
			&snippet.Context,
			&snippet.DifficultyLevel,
			&snippet.CreatedAt,
			&snippet.UpdatedAt,
		)
		if scanErr != nil {
			return nil, 0, contextutils.WrapErrorf(scanErr, "failed to scan snippet")
		}
		snippets = append(snippets, api.Snippet{
			Id:              &snippet.ID,
			UserId:          &snippet.UserID,
			OriginalText:    &snippet.OriginalText,
			TranslatedText:  &snippet.TranslatedText,
			SourceLanguage:  &snippet.SourceLanguage,
			TargetLanguage:  &snippet.TargetLanguage,
			QuestionId:      snippet.QuestionID,
			SectionId:       snippet.SectionID,
			StoryId:         snippet.StoryID,
			Context:         snippet.Context,
			DifficultyLevel: snippet.DifficultyLevel,
			CreatedAt:       &snippet.CreatedAt,
			UpdatedAt:       &snippet.UpdatedAt,
		})
	}
	// rows.Next returning false can mean "done" or "failed"; without this
	// check a mid-iteration error would be returned as a truncated page.
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapErrorf(err, "error iterating over snippet rows")
	}
	return snippets, total, nil
}
// snippetExists reports whether the user already has a snippet with exactly
// this original text, source language and target language.
//
// The results are named so that the deferred FinishSpan call receives a
// pointer to the error actually being returned, the same way the other
// methods of this service hand their error to FinishSpan.
func (s *SnippetsService) snippetExists(ctx context.Context, userID int64, originalText, sourceLanguage, targetLanguage string) (exists bool, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "snippet_exists")
	defer observability.FinishSpan(span, &err)
	// Check if database connection is valid
	if s.db == nil {
		return false, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(observability.AttributeUserID(int(userID)))

	query := `
SELECT COUNT(*)
FROM snippets
WHERE user_id = $1 AND original_text = $2 AND source_language = $3 AND target_language = $4`
	var count int
	err = s.db.QueryRowContext(ctx, query, userID, originalText, sourceLanguage, targetLanguage).Scan(&count)
	if err != nil {
		return false, contextutils.WrapErrorf(err, "failed to check snippet existence")
	}
	return count > 0, nil
}
// GetSnippet loads one snippet by ID, restricted to snippets owned by userID.
// It returns a wrapped ErrRecordNotFound when no row matches both IDs.
//
// The query does not select section_id or story_id, so SectionID and StoryID
// keep their zero value on the returned snippet.
func (s *SnippetsService) GetSnippet(ctx context.Context, userID, snippetID int64) (result *models.Snippet, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "get_snippet")
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(observability.AttributeUserID(int(userID)))
	span.SetAttributes(observability.AttributeSnippetID(int(snippetID)))

	row := s.db.QueryRowContext(ctx, `
SELECT id, user_id, original_text, translated_text, source_language, target_language,
question_id, context, difficulty_level, created_at, updated_at
FROM snippets
WHERE id = $1 AND user_id = $2`, snippetID, userID)

	result = &models.Snippet{}
	err = row.Scan(
		&result.ID,
		&result.UserID,
		&result.OriginalText,
		&result.TranslatedText,
		&result.SourceLanguage,
		&result.TargetLanguage,
		&result.QuestionID,
		&result.Context,
		&result.DifficultyLevel,
		&result.CreatedAt,
		&result.UpdatedAt,
	)
	switch {
	case err == nil:
		return result, nil
	case err == sql.ErrNoRows:
		return nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "snippet not found")
	default:
		return nil, contextutils.WrapErrorf(err, "failed to get snippet")
	}
}
// UpdateSnippet applies the non-nil fields of req to the snippet identified
// by snippetID and owned by userID, and always refreshes updated_at.
//
// It returns a wrapped ErrInvalidInput when req carries no field to change,
// a wrapped ErrRecordNotFound when no row matches both IDs, and a wrapped
// ErrRecordExists when the update trips a unique constraint (SQLSTATE 23505).
//
// The RETURNING list does not include section_id or story_id, so SectionID
// and StoryID keep their zero value on the returned snippet.
func (s *SnippetsService) UpdateSnippet(ctx context.Context, userID, snippetID int64, req api.UpdateSnippetRequest) (result *models.Snippet, err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "update_snippet")
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(observability.AttributeUserID(int(userID)))
	span.SetAttributes(observability.AttributeSnippetID(int(snippetID)))

	// Collect "column = $n" assignments for the fields present in the request;
	// each bound value takes the next placeholder number.
	assignments := []string{"updated_at = CURRENT_TIMESTAMP"}
	values := []interface{}{}
	assign := func(column string, value interface{}) {
		values = append(values, value)
		assignments = append(assignments, fmt.Sprintf("%s = $%d", column, len(values)))
	}
	if req.OriginalText != nil {
		assign("original_text", *req.OriginalText)
	}
	if req.TranslatedText != nil {
		assign("translated_text", *req.TranslatedText)
	}
	if req.SourceLanguage != nil {
		assign("source_language", *req.SourceLanguage)
	}
	if req.TargetLanguage != nil {
		assign("target_language", *req.TargetLanguage)
	}
	if req.Context != nil {
		assign("context", *req.Context)
	}
	// Only the implicit updated_at assignment is present: nothing to do.
	if len(values) == 0 {
		return nil, contextutils.WrapError(contextutils.ErrInvalidInput, "no fields to update")
	}

	// The row selector uses the two placeholders after the SET values.
	idPos := len(values) + 1
	whereClause := fmt.Sprintf("WHERE id = $%d AND user_id = $%d", idPos, idPos+1)
	values = append(values, snippetID, userID)

	query := fmt.Sprintf(`
UPDATE snippets
SET %s
%s
RETURNING id, user_id, original_text, translated_text, source_language, target_language,
question_id, context, difficulty_level, created_at, updated_at`,
		strings.Join(assignments, ", "), whereClause)

	result = &models.Snippet{}
	err = s.db.QueryRowContext(ctx, query, values...).Scan(
		&result.ID,
		&result.UserID,
		&result.OriginalText,
		&result.TranslatedText,
		&result.SourceLanguage,
		&result.TargetLanguage,
		&result.QuestionID,
		&result.Context,
		&result.DifficultyLevel,
		&result.CreatedAt,
		&result.UpdatedAt,
	)
	if err == nil {
		s.logger.Info(ctx, "Updated snippet",
			map[string]any{
				"snippet_id": result.ID,
				"user_id":    userID,
			})
		return result, nil
	}

	if err == sql.ErrNoRows {
		return nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "snippet not found")
	}
	// Map unique constraint violations to conflict error (409)
	pqErr, isPQ := err.(*pq.Error)
	if isPQ && pqErr.Code == "23505" {
		return nil, contextutils.WrapError(contextutils.ErrRecordExists, "a snippet with the same text and language already exists in this context")
	}
	return nil, contextutils.WrapErrorf(err, "failed to update snippet")
}
// DeleteSnippet removes the snippet identified by snippetID if it is owned by
// userID. It returns a wrapped ErrRecordNotFound when no row was deleted.
func (s *SnippetsService) DeleteSnippet(ctx context.Context, userID, snippetID int64) (err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "delete_snippet")
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(observability.AttributeUserID(int(userID)))
	span.SetAttributes(observability.AttributeSnippetID(int(snippetID)))

	res, err := s.db.ExecContext(ctx, "DELETE FROM snippets WHERE id = $1 AND user_id = $2", snippetID, userID)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to delete snippet")
	}

	// Zero affected rows means the snippet does not exist for this user.
	deleted, err := res.RowsAffected()
	switch {
	case err != nil:
		return contextutils.WrapErrorf(err, "failed to get rows affected")
	case deleted == 0:
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "snippet not found")
	}

	logFields := map[string]any{
		"snippet_id": snippetID,
		"user_id":    userID,
	}
	s.logger.Info(ctx, "Deleted snippet", logFields)
	return nil
}
// DeleteAllSnippets removes every snippet owned by userID and logs how many
// rows were deleted. Deleting zero rows is not an error.
func (s *SnippetsService) DeleteAllSnippets(ctx context.Context, userID int64) (err error) {
	ctx, span := observability.TraceFunction(ctx, "snippets", "delete_all_snippets")
	defer observability.FinishSpan(span, &err)

	// Refuse to run without a database handle.
	if s.db == nil {
		return contextutils.WrapError(contextutils.ErrInternalError, "database connection is nil")
	}
	span.SetAttributes(observability.AttributeUserID(int(userID)))

	res, err := s.db.ExecContext(ctx, "DELETE FROM snippets WHERE user_id = $1", userID)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to delete all snippets for user")
	}

	deleted, err := res.RowsAffected()
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to get rows affected")
	}

	logFields := map[string]any{
		"user_id":          userID,
		"snippets_deleted": deleted,
	}
	s.logger.Info(ctx, "Deleted all snippets for user", logFields)
	return nil
}
package services
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/attribute"
)
// StoryServiceInterface defines the interface for story operations.
//
// Methods that take a userID alongside a storyID or sectionID are the
// user-facing variants; the admin-only methods at the end of the interface
// are documented inline as performing no ownership checks.
type StoryServiceInterface interface {
	// Stories.
	CreateStory(ctx context.Context, userID uint, language string, req *models.CreateStoryRequest) (*models.Story, error)
	GetUserStories(ctx context.Context, userID uint, includeArchived bool) ([]models.Story, error)
	GetCurrentStory(ctx context.Context, userID uint) (*models.StoryWithSections, error)
	GetStory(ctx context.Context, storyID, userID uint) (*models.StoryWithSections, error)
	ArchiveStory(ctx context.Context, storyID, userID uint) error
	CompleteStory(ctx context.Context, storyID, userID uint) error
	SetCurrentStory(ctx context.Context, storyID, userID uint) error
	ToggleAutoGeneration(ctx context.Context, storyID, userID uint, paused bool) error
	DeleteStory(ctx context.Context, storyID, userID uint) error
	DeleteAllStoriesForUser(ctx context.Context, userID uint) error
	FixCurrentStoryConstraint(ctx context.Context) error
	// Story sections.
	GetStorySections(ctx context.Context, storyID uint) ([]models.StorySection, error)
	GetSection(ctx context.Context, sectionID, userID uint) (*models.StorySectionWithQuestions, error)
	CreateSection(ctx context.Context, storyID uint, content, level string, wordCount int, generatedBy models.GeneratorType) (*models.StorySection, error)
	GetLatestSection(ctx context.Context, storyID uint) (*models.StorySection, error)
	GetAllSectionsText(ctx context.Context, storyID uint) (string, error)
	// Section questions.
	GetSectionQuestions(ctx context.Context, sectionID uint) ([]models.StorySectionQuestion, error)
	CreateSectionQuestions(ctx context.Context, sectionID uint, questions []models.StorySectionQuestionData) error
	GetRandomQuestions(ctx context.Context, sectionID uint, count int) ([]models.StorySectionQuestion, error)
	// Generation bookkeeping and view tracking.
	UpdateLastGenerationTime(ctx context.Context, storyID uint, generatorType models.GeneratorType) error
	RecordStorySectionView(ctx context.Context, userID, sectionID uint) error
	HasUserViewedLatestSection(ctx context.Context, userID uint) (bool, error)
	// Helpers that take no context and return no error.
	GetSectionLengthTarget(level string, lengthPref *models.SectionLength) int
	GetSectionLengthTargetWithLanguage(language, level string, lengthPref *models.SectionLength) int
	SanitizeInput(input string) string
	GenerateStorySection(ctx context.Context, storyID, userID uint, aiService AIServiceInterface, userAIConfig *models.UserAIConfig, generatorType models.GeneratorType) (*models.StorySectionWithQuestions, error)
	// Admin-only helpers (no ownership checks)
	GetStoriesPaginated(ctx context.Context, page, pageSize int, search, language, status string, userID *uint) ([]models.Story, int, error)
	GetStoryAdmin(ctx context.Context, storyID uint) (*models.StoryWithSections, error)
	GetSectionAdmin(ctx context.Context, sectionID uint) (*models.StorySectionWithQuestions, error)
	// Admin-only delete without ownership check
	DeleteStoryAdmin(ctx context.Context, storyID uint) error
}
// StoryService groups the story-related persistence operations in this file:
// story lifecycle (create, archive, complete, delete), section and question
// storage, and the generation-eligibility check.
type StoryService struct {
	// db is the database handle used for every query and transaction.
	db *sql.DB
	// config supplies the story settings read through config.Story.
	config *config.Config
	// logger records informational events emitted by story operations.
	logger *observability.Logger
}
// NewStoryService builds a StoryService around the given database handle,
// configuration and logger.
//
// It panics when db is nil. config and logger are stored as given, without
// validation.
func NewStoryService(db *sql.DB, config *config.Config, logger *observability.Logger) *StoryService {
	if db == nil {
		panic("StoryService requires a valid database connection")
	}

	svc := new(StoryService)
	svc.db = db
	svc.config = config
	svc.logger = logger
	return svc
}
// CreateStory creates a new active story for the user in the given language.
//
// The request is validated first; a nil req is rejected with an error instead
// of being dereferenced. Creation is refused once the user's archived-story
// count has reached config.Story.MaxArchivedPerUser. Any story of the user that
// is currently active in the same language is archived before the new story is
// handed to createStory with status active.
//
// It returns the story that was passed to createStory, or a wrapped error from
// whichever step failed.
//
// NOTE(review): the archive UPDATE runs directly on s.db before createStory is
// called; if createStory fails, the previously active story stays archived.
// Confirm whether both steps should share a transaction.
func (s *StoryService) CreateStory(ctx context.Context, userID uint, language string, req *models.CreateStoryRequest) (*models.Story, error) {
	if req == nil {
		return nil, contextutils.ErrorWithContextf("invalid story request: request is nil")
	}
	if err := req.Validate(); err != nil {
		return nil, contextutils.WrapErrorf(err, "invalid story request")
	}

	// Check if user has reached the archived story limit
	archivedCount, err := s.getArchivedStoryCount(ctx, userID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to check archived story count")
	}
	if archivedCount >= s.config.Story.MaxArchivedPerUser {
		return nil, contextutils.ErrorWithContextf("maximum archived stories limit reached (%d)", s.config.Story.MaxArchivedPerUser)
	}

	// The level itself is not used yet; the lookup is kept so its failure
	// still aborts story creation.
	if _, err := s.getUserCurrentLevel(ctx, userID); err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get user level")
	}

	// Unset any existing active story in the same language first
	unsetQuery := "UPDATE stories SET status = $1, updated_at = NOW() WHERE user_id = $2 AND language = $3 AND status = $4"
	if _, err := s.db.ExecContext(ctx, unsetQuery, models.StoryStatusArchived, userID, language, models.StoryStatusActive); err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to unset existing current story")
	}

	// Read the clock once so CreatedAt and UpdatedAt are identical on a new story.
	now := time.Now()
	story := &models.Story{
		UserID:                userID,
		Title:                 req.Title,
		Language:              language,
		Subject:               req.Subject,
		AuthorStyle:           req.AuthorStyle,
		TimePeriod:            req.TimePeriod,
		Genre:                 req.Genre,
		Tone:                  req.Tone,
		CharacterNames:        req.CharacterNames,
		CustomInstructions:    req.CustomInstructions,
		SectionLengthOverride: req.SectionLengthOverride,
		Status:                models.StoryStatusActive,
		CreatedAt:             now,
		UpdatedAt:             now,
	}
	if err := s.createStory(ctx, story); err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to create story")
	}

	// Log with the caller's ctx (not context.Background()) so any
	// request-scoped data the logger reads from the context is preserved.
	s.logger.Info(ctx, "Story created successfully",
		map[string]interface{}{
			"story_id": story.ID,
			"user_id":  userID,
			"title":    story.Title,
		})
	return story, nil
}
// GetUserStories retrieves the user's stories in their preferred language
// ("en" when no preference is set).
//
// Archived stories are excluded unless includeArchived is true. Rows are
// ordered with active stories first, then by creation time, newest first.
// A user that does not exist yields an empty, non-nil slice and no error.
//
// Every failure (user lookup, query, row scan, row iteration) is returned as a
// wrapped error with a nil slice.
func (s *StoryService) GetUserStories(ctx context.Context, userID uint, includeArchived bool) ([]models.Story, error) {
	// Get user's preferred language
	user, err := s.getUserByID(ctx, userID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get user")
	}
	if user == nil {
		// Return empty slice for non-existent user instead of error
		return []models.Story{}, nil
	}

	language := "en" // default
	if user.PreferredLanguage.Valid {
		language = user.PreferredLanguage.String
	}

	query := `
SELECT id, user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
auto_generation_paused, last_section_generated_at, created_at, updated_at
FROM stories
WHERE user_id = $1 AND language = $2`
	args := []interface{}{userID, language}
	if !includeArchived {
		query += " AND status != $3"
		args = append(args, models.StoryStatusArchived)
	}
	query += " ORDER BY status = 'active' DESC, created_at DESC"

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to query user stories")
	}
	defer func() { _ = rows.Close() }()

	// Non-nil so an empty result is still an empty slice rather than nil.
	stories := []models.Story{}
	for rows.Next() {
		var story models.Story
		if err := rows.Scan(
			&story.ID, &story.UserID, &story.Title, &story.Language, &story.Subject,
			&story.AuthorStyle, &story.TimePeriod, &story.Genre, &story.Tone,
			&story.CharacterNames, &story.CustomInstructions, &story.SectionLengthOverride,
			&story.Status, &story.AutoGenerationPaused,
			&story.LastSectionGeneratedAt,
			&story.CreatedAt, &story.UpdatedAt,
		); err != nil {
			return nil, contextutils.WrapErrorf(err, "failed to scan user story")
		}
		stories = append(stories, story)
	}
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to iterate user stories")
	}
	return stories, nil
}
// GetCurrentStory loads the user's active story in their preferred language
// ("en" when no preference is set) together with its sections.
//
// It returns (nil, nil) when no such story exists, and an error when the user
// cannot be found or a query fails.
func (s *StoryService) GetCurrentStory(ctx context.Context, userID uint) (*models.StoryWithSections, error) {
	owner, err := s.getUserByID(ctx, userID)
	switch {
	case err != nil:
		return nil, contextutils.WrapErrorf(err, "failed to get user")
	case owner == nil:
		return nil, contextutils.ErrorWithContextf("user not found: %d", userID)
	}

	// Fall back to English when the user has no stored preference.
	lang := "en"
	if owner.PreferredLanguage.Valid {
		lang = owner.PreferredLanguage.String
	}

	const currentStoryQuery = `
SELECT id, user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
auto_generation_paused, last_section_generated_at, created_at, updated_at
FROM stories
WHERE user_id = $1 AND language = $2 AND status = $3 AND status != $4`

	var current models.Story
	row := s.db.QueryRowContext(ctx, currentStoryQuery, userID, lang, models.StoryStatusActive, models.StoryStatusArchived)
	scanErr := row.Scan(
		&current.ID, &current.UserID, &current.Title, &current.Language, &current.Subject,
		&current.AuthorStyle, &current.TimePeriod, &current.Genre, &current.Tone,
		&current.CharacterNames, &current.CustomInstructions, &current.SectionLengthOverride,
		&current.Status, &current.AutoGenerationPaused,
		&current.LastSectionGeneratedAt,
		&current.CreatedAt, &current.UpdatedAt,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		// No current story in the user's language is not an error.
		return nil, nil
	case scanErr != nil:
		return nil, contextutils.WrapErrorf(scanErr, "failed to get current story")
	}

	// Attach the story's sections.
	sections, err := s.GetStorySections(ctx, current.ID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to load story sections")
	}

	result := &models.StoryWithSections{
		Story:    current,
		Sections: sections,
	}
	return result, nil
}
// GetStory loads one story together with its sections, but only when the story
// belongs to userID.
//
// A story that does not exist and a story owned by someone else are
// indistinguishable to the caller: both produce the same
// "story not found or access denied" error.
func (s *StoryService) GetStory(ctx context.Context, storyID, userID uint) (*models.StoryWithSections, error) {
	const ownedStoryQuery = `
SELECT id, user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
auto_generation_paused, last_section_generated_at, created_at, updated_at
FROM stories
WHERE id = $1 AND user_id = $2`

	var found models.Story
	row := s.db.QueryRowContext(ctx, ownedStoryQuery, storyID, userID)
	scanErr := row.Scan(
		&found.ID, &found.UserID, &found.Title, &found.Language, &found.Subject,
		&found.AuthorStyle, &found.TimePeriod, &found.Genre, &found.Tone,
		&found.CharacterNames, &found.CustomInstructions, &found.SectionLengthOverride,
		&found.Status, &found.AutoGenerationPaused,
		&found.LastSectionGeneratedAt,
		&found.CreatedAt, &found.UpdatedAt,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, contextutils.ErrorWithContextf("story not found or access denied")
	case scanErr != nil:
		return nil, contextutils.WrapErrorf(scanErr, "failed to get story")
	}

	// Attach the story's sections.
	sections, err := s.GetStorySections(ctx, found.ID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to load story sections")
	}

	result := &models.StoryWithSections{
		Story:    found,
		Sections: sections,
	}
	return result, nil
}
// ArchiveStory sets a story's status to archived.
//
// Ownership is checked through validateStoryOwnership first. Completed stories
// are rejected with an error. Because only active stories count as current,
// archiving also removes the story from being the user's current one.
func (s *StoryService) ArchiveStory(ctx context.Context, storyID, userID uint) error {
	if err := s.validateStoryOwnership(ctx, storyID, userID); err != nil {
		return err
	}

	// First, check if the story is completed (completed stories cannot be archived)
	var status string
	checkQuery := "SELECT status FROM stories WHERE id = $1"
	if err := s.db.QueryRowContext(ctx, checkQuery, storyID).Scan(&status); err != nil {
		return contextutils.WrapErrorf(err, "failed to check story status")
	}

	// Prevent archiving completed stories
	if status == string(models.StoryStatusCompleted) {
		return contextutils.ErrorWithContextf("cannot archive completed stories")
	}

	// Archive the story
	query := "UPDATE stories SET status = $1, updated_at = NOW() WHERE id = $2"
	if _, err := s.db.ExecContext(ctx, query, models.StoryStatusArchived, storyID); err != nil {
		return contextutils.WrapErrorf(err, "failed to archive story")
	}

	// Log with the caller's ctx (not context.Background()) so any
	// request-scoped data the logger reads from the context is preserved.
	s.logger.Info(ctx, "Story archived successfully",
		map[string]interface{}{
			"story_id": storyID,
			"user_id":  userID,
		})
	return nil
}
// CompleteStory sets a story's status to completed after checking ownership
// through validateStoryOwnership.
//
// The story's current status is not inspected before the update.
func (s *StoryService) CompleteStory(ctx context.Context, storyID, userID uint) error {
	if err := s.validateStoryOwnership(ctx, storyID, userID); err != nil {
		return err
	}

	query := "UPDATE stories SET status = $1, updated_at = NOW() WHERE id = $2"
	if _, err := s.db.ExecContext(ctx, query, models.StoryStatusCompleted, storyID); err != nil {
		return contextutils.WrapErrorf(err, "failed to complete story")
	}

	// Log with the caller's ctx (not context.Background()) so any
	// request-scoped data the logger reads from the context is preserved.
	s.logger.Info(ctx, "Story completed successfully",
		map[string]interface{}{
			"story_id": storyID,
			"user_id":  userID,
		})
	return nil
}
// ToggleAutoGeneration stores the auto-generation pause flag for a story.
//
// Despite the name, the flag is set to the given paused value rather than
// flipped. Ownership is checked through validateStoryOwnership first.
func (s *StoryService) ToggleAutoGeneration(ctx context.Context, storyID, userID uint, paused bool) error {
	if err := s.validateStoryOwnership(ctx, storyID, userID); err != nil {
		return err
	}

	query := "UPDATE stories SET auto_generation_paused = $1, updated_at = NOW() WHERE id = $2"
	if _, err := s.db.ExecContext(ctx, query, paused, storyID); err != nil {
		return contextutils.WrapErrorf(err, "failed to toggle auto-generation")
	}

	// Log with the caller's ctx (not context.Background()) so any
	// request-scoped data the logger reads from the context is preserved.
	s.logger.Info(ctx, "Story auto-generation toggled",
		map[string]interface{}{
			"story_id": storyID,
			"user_id":  userID,
			"paused":   paused,
		})
	return nil
}
// SetCurrentStory makes the given story the user's active (current) story for
// its language.
//
// Preconditions, each reported as an error when violated:
//   - the story exists and belongs to userID;
//   - the story is not completed;
//   - the story's language equals the user's preferred language ("en" when no
//     preference is set).
//
// Archiving the user's other active stories in that language and activating the
// requested one happen inside a single transaction, so a failure cannot leave
// the user without any active story.
func (s *StoryService) SetCurrentStory(ctx context.Context, storyID, userID uint) error {
	if err := s.validateStoryOwnership(ctx, storyID, userID); err != nil {
		return err
	}

	// Get the story's language and status
	query := "SELECT language, status FROM stories WHERE id = $1 AND user_id = $2"
	var language string
	var storyStatus models.StoryStatus
	err := s.db.QueryRowContext(ctx, query, storyID, userID).Scan(&language, &storyStatus)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return contextutils.ErrorWithContextf("story not found or access denied")
		}
		return contextutils.WrapErrorf(err, "failed to get story language and status")
	}

	// Completed stories cannot be made current again.
	if storyStatus == models.StoryStatusCompleted {
		return contextutils.ErrorWithContextf("cannot restore completed stories")
	}

	// Get the user's preferred language
	user, err := s.getUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to get user")
	}
	if user == nil {
		return contextutils.ErrorWithContextf("user not found")
	}
	userPreferredLanguage := "en" // default
	if user.PreferredLanguage.Valid {
		userPreferredLanguage = user.PreferredLanguage.String
	}

	// Check if the story's language matches the user's preferred language
	if language != userPreferredLanguage {
		return contextutils.ErrorWithContextf("cannot restore story in different language than preferred language")
	}

	// Both status changes must succeed or fail together.
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to begin transaction")
	}
	defer func() { _ = tx.Rollback() }()

	// Archive any existing active story in the same language for this user
	// (since only one story can be active per user per language)
	unsetQuery := "UPDATE stories SET status = $1, updated_at = NOW() WHERE user_id = $2 AND language = $3 AND status = $4"
	if _, err := tx.ExecContext(ctx, unsetQuery, models.StoryStatusArchived, userID, language, models.StoryStatusActive); err != nil {
		return contextutils.WrapErrorf(err, "failed to unset existing active story")
	}

	// Set the specified story as active (which makes it current)
	setQuery := "UPDATE stories SET status = $1, updated_at = NOW() WHERE id = $2"
	if _, err := tx.ExecContext(ctx, setQuery, models.StoryStatusActive, storyID); err != nil {
		return contextutils.WrapErrorf(err, "failed to set current story")
	}

	if err := tx.Commit(); err != nil {
		return contextutils.WrapErrorf(err, "failed to commit current story change")
	}
	return nil
}
// FixCurrentStoryConstraint repairs every case where a user has more than one
// active story in the same language.
//
// All offending (user, language) pairs are read first, and the result set is
// released before any repair begins. Each repair opens its own transaction, so
// keeping the listing query open while repairing would hold one connection
// while requesting another and could stall with a small connection pool.
//
// Repairs are applied pair by pair; the first failure aborts the run and is
// returned wrapped. Pairs repaired before the failure stay repaired.
func (s *StoryService) FixCurrentStoryConstraint(ctx context.Context) error {
	// Find all users who have multiple active stories in the same language
	query := `
SELECT user_id, language, COUNT(*) as active_count
FROM stories
WHERE status = 'active'
GROUP BY user_id, language
HAVING COUNT(*) > 1`
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to find users with multiple active stories in same language")
	}
	defer func() { _ = rows.Close() }()

	type violation struct {
		userID   uint
		language string
	}
	var violations []violation
	for rows.Next() {
		var v violation
		// active_count is selected by the query, so it needs a scan target
		// even though its value is not used.
		var activeCount int
		if err := rows.Scan(&v.userID, &v.language, &activeCount); err != nil {
			return contextutils.WrapErrorf(err, "failed to scan user constraint violation")
		}
		violations = append(violations, v)
	}
	if err := rows.Err(); err != nil {
		return contextutils.WrapErrorf(err, "failed to iterate users with multiple active stories")
	}
	// Release the result set before the per-pair transactions start.
	_ = rows.Close()

	for _, v := range violations {
		if err := s.fixUserCurrentStoryConstraint(ctx, v.userID, v.language); err != nil {
			return contextutils.WrapErrorf(err, "failed to fix constraint for user %d in language %s", v.userID, v.language)
		}
	}
	return nil
}
// fixUserCurrentStoryConstraint leaves at most one active story for the given
// user and language.
//
// Inside one transaction it lists the user's active stories in that language,
// most recently updated first, keeps the first and archives the rest. With
// zero or one active story nothing is changed. Any failure rolls the
// transaction back through the deferred Rollback.
func (s *StoryService) fixUserCurrentStoryConstraint(ctx context.Context, userID uint, language string) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to begin transaction")
	}
	defer func() { _ = tx.Rollback() }()

	// Find all active stories for this user in this language, ordered by most recently updated
	selectQuery := `
SELECT id FROM stories
WHERE user_id = $1 AND language = $2 AND status = 'active'
ORDER BY updated_at DESC`
	rows, err := tx.QueryContext(ctx, selectQuery, userID, language)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to find active stories for user in language")
	}
	defer func() { _ = rows.Close() }()

	var activeStories []uint
	for rows.Next() {
		var storyID uint
		if err := rows.Scan(&storyID); err != nil {
			return contextutils.WrapErrorf(err, "failed to scan story ID")
		}
		activeStories = append(activeStories, storyID)
	}
	// An iteration error means the list may be incomplete; acting on it could
	// leave several stories active, so stop here.
	if err := rows.Err(); err != nil {
		return contextutils.WrapErrorf(err, "failed to iterate active stories for user in language")
	}
	// Release the result set before issuing further statements on this tx.
	if err := rows.Close(); err != nil {
		return contextutils.WrapErrorf(err, "failed to close active stories result set")
	}

	if len(activeStories) <= 1 {
		// No constraint violation for this user in this language
		return tx.Commit()
	}

	// Archive all active stories except the most recently updated one
	const archiveQuery = "UPDATE stories SET status = $1, updated_at = NOW() WHERE id = $2"
	for _, storyID := range activeStories[1:] {
		if _, err := tx.ExecContext(ctx, archiveQuery, models.StoryStatusArchived, storyID); err != nil {
			return contextutils.WrapErrorf(err, "failed to unset active story %d", storyID)
		}
	}
	return tx.Commit()
}
// DeleteStory permanently removes a story owned by userID together with its
// sections and their questions.
//
// Only archived or completed stories may be deleted; any other status is
// rejected with "cannot delete active story". A missing story and a story
// owned by someone else both yield "story not found or access denied".
// The three deletes run in one transaction.
func (s *StoryService) DeleteStory(ctx context.Context, storyID, userID uint) error {
	// Load the story, scoped to the owner, to confirm it exists and read its status.
	const lookupQuery = `
SELECT id, user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
last_section_generated_at, created_at, updated_at
FROM stories
WHERE id = $1 AND user_id = $2`

	var target models.Story
	lookupErr := s.db.QueryRowContext(ctx, lookupQuery, storyID, userID).Scan(
		&target.ID, &target.UserID, &target.Title, &target.Language, &target.Subject,
		&target.AuthorStyle, &target.TimePeriod, &target.Genre, &target.Tone,
		&target.CharacterNames, &target.CustomInstructions, &target.SectionLengthOverride,
		&target.Status, &target.LastSectionGeneratedAt,
		&target.CreatedAt, &target.UpdatedAt,
	)
	switch {
	case errors.Is(lookupErr, sql.ErrNoRows):
		return contextutils.ErrorWithContextf("story not found or access denied")
	case lookupErr != nil:
		return contextutils.WrapErrorf(lookupErr, "failed to get story")
	}

	switch target.Status {
	case models.StoryStatusArchived, models.StoryStatusCompleted:
		// Deletable.
	default:
		return contextutils.ErrorWithContextf("cannot delete active story")
	}

	// All three deletes succeed or fail together.
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to begin transaction")
	}
	defer func() { _ = tx.Rollback() }()

	// Children first: questions reference sections, sections reference the story.
	deleteQuestions := "DELETE FROM story_section_questions WHERE section_id IN (SELECT id FROM story_sections WHERE story_id = $1)"
	if _, err := tx.ExecContext(ctx, deleteQuestions, storyID); err != nil {
		return contextutils.WrapErrorf(err, "failed to delete story questions")
	}

	deleteSections := "DELETE FROM story_sections WHERE story_id = $1"
	if _, err := tx.ExecContext(ctx, deleteSections, storyID); err != nil {
		return contextutils.WrapErrorf(err, "failed to delete story sections")
	}

	deleteStory := "DELETE FROM stories WHERE id = $1"
	if _, err := tx.ExecContext(ctx, deleteStory, storyID); err != nil {
		return contextutils.WrapErrorf(err, "failed to delete story")
	}

	return tx.Commit()
}
// DeleteStoryAdmin permanently removes a story by ID together with its sections
// and their questions. It performs no ownership check and no status check, so
// it is meant for admin-only callers.
//
// A missing story yields "story not found". The three deletes run in one
// transaction.
func (s *StoryService) DeleteStoryAdmin(ctx context.Context, storyID uint) error {
	// Load the story first so a missing ID is reported as an error.
	const lookupQuery = `
SELECT id, user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
last_section_generated_at, created_at, updated_at
FROM stories
WHERE id = $1`

	var target models.Story
	lookupErr := s.db.QueryRowContext(ctx, lookupQuery, storyID).Scan(
		&target.ID, &target.UserID, &target.Title, &target.Language, &target.Subject,
		&target.AuthorStyle, &target.TimePeriod, &target.Genre, &target.Tone,
		&target.CharacterNames, &target.CustomInstructions, &target.SectionLengthOverride,
		&target.Status, &target.LastSectionGeneratedAt,
		&target.CreatedAt, &target.UpdatedAt,
	)
	switch {
	case errors.Is(lookupErr, sql.ErrNoRows):
		return contextutils.ErrorWithContextf("story not found")
	case lookupErr != nil:
		return contextutils.WrapErrorf(lookupErr, "failed to get story")
	}

	// No status restriction here: every story may be deleted.
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to begin transaction")
	}
	defer func() { _ = tx.Rollback() }()

	// Children first: questions reference sections, sections reference the story.
	_, err = tx.ExecContext(ctx, "DELETE FROM story_section_questions WHERE section_id IN (SELECT id FROM story_sections WHERE story_id = $1)", storyID)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to delete story questions")
	}

	_, err = tx.ExecContext(ctx, "DELETE FROM story_sections WHERE story_id = $1", storyID)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to delete story sections")
	}

	_, err = tx.ExecContext(ctx, "DELETE FROM stories WHERE id = $1", storyID)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to delete story")
	}

	return tx.Commit()
}
// DeleteAllStoriesForUser deletes every story of the given user along with the
// stories' sections and the sections' questions.
//
// The deletes run children-first (questions, sections, stories) inside one
// transaction; any failure rolls everything back through the deferred Rollback.
// No status or ownership checks are performed beyond matching user_id.
func (s *StoryService) DeleteAllStoriesForUser(ctx context.Context, userID uint) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to begin transaction")
	}
	defer func() { _ = tx.Rollback() }()

	// Delete questions for all sections belonging to stories of this user
	q1 := `DELETE FROM story_section_questions WHERE section_id IN (SELECT id FROM story_sections WHERE story_id IN (SELECT id FROM stories WHERE user_id = $1))`
	if _, err := tx.ExecContext(ctx, q1, userID); err != nil {
		return contextutils.WrapErrorf(err, "failed to delete story questions for user %d", userID)
	}

	// Delete sections for all stories belonging to this user
	q2 := `DELETE FROM story_sections WHERE story_id IN (SELECT id FROM stories WHERE user_id = $1)`
	if _, err := tx.ExecContext(ctx, q2, userID); err != nil {
		return contextutils.WrapErrorf(err, "failed to delete story sections for user %d", userID)
	}

	// Finally delete stories
	q3 := `DELETE FROM stories WHERE user_id = $1`
	if _, err := tx.ExecContext(ctx, q3, userID); err != nil {
		return contextutils.WrapErrorf(err, "failed to delete stories for user %d", userID)
	}

	if err := tx.Commit(); err != nil {
		return contextutils.WrapErrorf(err, "failed to commit delete all stories transaction for user %d", userID)
	}

	// Log with the caller's ctx (not context.Background()) so any
	// request-scoped data the logger reads from the context is preserved.
	s.logger.Info(ctx, "Deleted all stories for user", map[string]interface{}{"user_id": userID})
	return nil
}
// GetStorySections retrieves all sections for a story
func (s *StoryService) GetStorySections(ctx context.Context, storyID uint) ([]models.StorySection, error) {
query := `
SELECT id, story_id, section_number, content, language_level, word_count,
generated_by, generated_at, generation_date
FROM story_sections
WHERE story_id = $1
ORDER BY section_number ASC`
rows, err := s.db.QueryContext(ctx, query, storyID)
if err != nil {
return nil, contextutils.WrapErrorf(err, "failed to get story sections")
}
defer func() { _ = rows.Close() }()
sections := make([]models.StorySection, 0)
for rows.Next() {
var section models.StorySection
err := rows.Scan(
§ion.ID, §ion.StoryID, §ion.SectionNumber, §ion.Content,
§ion.LanguageLevel, §ion.WordCount, §ion.GeneratedBy, §ion.GeneratedAt, §ion.GenerationDate,
)
if err != nil {
return nil, contextutils.WrapErrorf(err, "failed to scan story section")
}
sections = append(sections, section)
}
return sections, rows.Err()
}
// GetSection retrieves a specific section with ownership verification
func (s *StoryService) GetSection(ctx context.Context, sectionID, userID uint) (*models.StorySectionWithQuestions, error) {
query := `
SELECT ss.id, ss.story_id, ss.section_number, ss.content, ss.language_level, ss.word_count,
ss.generated_by, ss.generated_at, ss.generation_date
FROM story_sections ss
JOIN stories s ON ss.story_id = s.id
WHERE ss.id = $1 AND s.user_id = $2`
var section models.StorySection
err := s.db.QueryRowContext(ctx, query, sectionID, userID).Scan(
§ion.ID, §ion.StoryID, §ion.SectionNumber, §ion.Content,
§ion.LanguageLevel, §ion.WordCount, §ion.GeneratedBy, §ion.GeneratedAt, §ion.GenerationDate,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, contextutils.ErrorWithContextf("section not found or access denied")
}
return nil, contextutils.WrapErrorf(err, "failed to get section")
}
// Load questions
questions, err := s.GetSectionQuestions(ctx, section.ID)
if err != nil {
return nil, contextutils.WrapErrorf(err, "failed to load section questions")
}
return &models.StorySectionWithQuestions{
StorySection: section,
Questions: questions,
}, nil
}
// CreateSection appends a new section to a story.
//
// The section number is one more than the story's current maximum (1 for the
// first section). GeneratedAt and GenerationDate are derived from a single
// clock reading so the two can never fall on different days.
//
// NOTE(review): the MAX(section_number) lookup and the insert are separate
// statements; concurrent calls for the same story could compute the same
// number. Confirm whether a uniqueness constraint or a lock covers this.
func (s *StoryService) CreateSection(ctx context.Context, storyID uint, content, level string, wordCount int, generatedBy models.GeneratorType) (*models.StorySection, error) {
	// Get the next section number
	var maxSectionNumber int
	query := "SELECT COALESCE(MAX(section_number), 0) FROM story_sections WHERE story_id = $1"
	if err := s.db.QueryRowContext(ctx, query, storyID).Scan(&maxSectionNumber); err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get max section number")
	}

	now := time.Now()
	section := &models.StorySection{
		StoryID:        storyID,
		SectionNumber:  maxSectionNumber + 1,
		Content:        content,
		LanguageLevel:  level,
		WordCount:      wordCount,
		GeneratedBy:    generatedBy,
		GeneratedAt:    now,
		GenerationDate: now.Truncate(24 * time.Hour),
	}
	if err := s.createSection(ctx, section); err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to create section")
	}
	return section, nil
}
// GetLatestSection retrieves the most recent section for a story
func (s *StoryService) GetLatestSection(ctx context.Context, storyID uint) (*models.StorySection, error) {
query := `
SELECT id, story_id, section_number, content, language_level, word_count,
generated_by, generated_at, generation_date
FROM story_sections
WHERE story_id = $1
ORDER BY section_number DESC
LIMIT 1`
var section models.StorySection
err := s.db.QueryRowContext(ctx, query, storyID).Scan(
§ion.ID, §ion.StoryID, §ion.SectionNumber, §ion.Content,
§ion.LanguageLevel, §ion.WordCount, §ion.GeneratedBy, §ion.GeneratedAt, §ion.GenerationDate,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil // No sections yet
}
return nil, contextutils.WrapErrorf(err, "failed to get latest section")
}
return §ion, nil
}
// GetAllSectionsText renders every section of a story into one string.
//
// Each section is written as "Section <number>:" followed by its content on
// the next line; consecutive sections are separated by a blank line. A story
// without sections yields the empty string. Errors from GetStorySections are
// returned as-is.
func (s *StoryService) GetAllSectionsText(ctx context.Context, storyID uint) (string, error) {
	sections, err := s.GetStorySections(ctx, storyID)
	if err != nil {
		return "", err
	}

	var out strings.Builder
	for idx := range sections {
		if idx != 0 {
			out.WriteString("\n\n")
		}
		fmt.Fprintf(&out, "Section %d:\n%s", sections[idx].SectionNumber, sections[idx].Content)
	}
	return out.String(), nil
}
// GetSectionQuestions returns every question stored for a section.
//
// The options column is read as raw JSON and decoded into each question's
// Options. The query has no ORDER BY, so row order is whatever the database
// returns. A section without questions yields an empty, non-nil slice.
func (s *StoryService) GetSectionQuestions(ctx context.Context, sectionID uint) ([]models.StorySectionQuestion, error) {
	const selectQuestions = `
SELECT id, section_id, question_text, options, correct_answer_index, explanation, created_at
FROM story_section_questions
WHERE section_id = $1`

	rows, err := s.db.QueryContext(ctx, selectQuestions, sectionID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get section questions")
	}
	defer func() { _ = rows.Close() }()

	result := []models.StorySectionQuestion{}
	for rows.Next() {
		var (
			item       models.StorySectionQuestion
			rawOptions []byte
		)
		if scanErr := rows.Scan(
			&item.ID, &item.SectionID, &item.QuestionText, &rawOptions,
			&item.CorrectAnswerIndex, &item.Explanation, &item.CreatedAt,
		); scanErr != nil {
			return nil, contextutils.WrapErrorf(scanErr, "failed to scan question")
		}

		// Options are stored as JSON; decode them back into the Go value.
		if decodeErr := json.Unmarshal(rawOptions, &item.Options); decodeErr != nil {
			return nil, contextutils.WrapErrorf(decodeErr, "failed to unmarshal options from JSON")
		}
		result = append(result, item)
	}
	return result, rows.Err()
}
// CreateSectionQuestions inserts all given questions for a section inside a
// single transaction: either every question is stored or none is.
//
// An empty questions slice is a no-op that touches the database not at all.
// The row inserts are delegated to createSectionQuestionsInTx so this method
// and callers that already hold a transaction share one implementation.
func (s *StoryService) CreateSectionQuestions(ctx context.Context, sectionID uint, questions []models.StorySectionQuestionData) error {
	if len(questions) == 0 {
		return nil
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to begin transaction")
	}
	// After a successful Commit this Rollback returns an error that is
	// deliberately ignored; on any early return it undoes the partial inserts.
	defer func() { _ = tx.Rollback() }()

	// Errors from the helper are already wrapped with context.
	if err := s.createSectionQuestionsInTx(ctx, tx, sectionID, questions); err != nil {
		return err
	}
	return tx.Commit()
}
// createSectionQuestionsInTx inserts the given questions for a section using a
// transaction owned by the caller. It neither commits nor rolls back; the
// caller decides the transaction's outcome.
//
// Each question's options are marshalled to JSON before the insert, and each
// row gets its own created_at timestamp. An empty slice is a no-op.
func (s *StoryService) createSectionQuestionsInTx(ctx context.Context, tx *sql.Tx, sectionID uint, questions []models.StorySectionQuestionData) error {
	const insertQuestion = `
INSERT INTO story_section_questions (
section_id, question_text, options, correct_answer_index, explanation, created_at
) VALUES ($1, $2, $3, $4, $5, $6)`

	for i := range questions {
		item := &questions[i]

		// Options are stored as JSON in the database.
		encoded, err := json.Marshal(item.Options)
		if err != nil {
			return contextutils.WrapErrorf(err, "failed to marshal options to JSON")
		}

		if _, err := tx.ExecContext(ctx, insertQuestion,
			sectionID, item.QuestionText, encoded, item.CorrectAnswerIndex, item.Explanation, time.Now(),
		); err != nil {
			return contextutils.WrapErrorf(err, "failed to insert question")
		}
	}
	return nil
}
// GetRandomQuestions returns up to count questions of a section, chosen by the
// database through ORDER BY RANDOM() with count passed as the LIMIT.
//
// The options column is read as raw JSON and decoded into each question's
// Options. When nothing matches, the result is an empty, non-nil slice.
func (s *StoryService) GetRandomQuestions(ctx context.Context, sectionID uint, count int) ([]models.StorySectionQuestion, error) {
	const selectRandom = `
SELECT id, section_id, question_text, options, correct_answer_index, explanation, created_at
FROM story_section_questions
WHERE section_id = $1
ORDER BY RANDOM()
LIMIT $2`

	rows, err := s.db.QueryContext(ctx, selectRandom, sectionID, count)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get random questions")
	}
	defer func() { _ = rows.Close() }()

	picked := []models.StorySectionQuestion{}
	for rows.Next() {
		var (
			item       models.StorySectionQuestion
			rawOptions []byte
		)
		scanErr := rows.Scan(
			&item.ID, &item.SectionID, &item.QuestionText, &rawOptions,
			&item.CorrectAnswerIndex, &item.Explanation, &item.CreatedAt,
		)
		if scanErr != nil {
			return nil, contextutils.WrapErrorf(scanErr, "failed to scan question")
		}

		// Options are stored as JSON; decode them back into the Go value.
		decodeErr := json.Unmarshal(rawOptions, &item.Options)
		if decodeErr != nil {
			return nil, contextutils.WrapErrorf(decodeErr, "failed to unmarshal options from JSON")
		}
		picked = append(picked, item)
	}
	return picked, rows.Err()
}
// canGenerateSection checks if a new section can be generated for a story today by a specific generator
func (s *StoryService) canGenerateSection(ctx context.Context, storyID uint, generatorType models.GeneratorType) (response *models.StoryGenerationEligibilityResponse, err error) {
ctx, span := observability.TraceFunction(ctx, "story_service", "can_generate_section",
attribute.Int("story.id", int(storyID)),
observability.AttributeGenerationType(generatorType),
)
defer observability.FinishSpan(span, &err)
query := `
SELECT status, last_section_generated_at, extra_generations_today
FROM stories
WHERE id = $1`
var status string
var lastGen *time.Time
var extraGenerationsToday int
err = s.db.QueryRowContext(ctx, query, storyID).Scan(&status, &lastGen, &extraGenerationsToday)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return &models.StoryGenerationEligibilityResponse{
CanGenerate: false,
Reason: "story not found",
}, nil
}
return nil, contextutils.WrapErrorf(err, "failed to get story")
}
// Check if story generation is enabled globally
if !s.config.Story.GenerationEnabled {
return &models.StoryGenerationEligibilityResponse{
CanGenerate: false,
Reason: "story generation is disabled globally",
}, nil
}
// Check if story is active (active stories are by definition current)
if status != string(models.StoryStatusActive) {
return &models.StoryGenerationEligibilityResponse{
CanGenerate: false,
Reason: "story is not active",
}, nil
}
// Check engagement-based generation if enabled and this is worker generation
// Manual user generation should always be allowed regardless of engagement
if s.config.Story.EngagementBasedGeneration && generatorType == models.GeneratorTypeWorker {
// Get the user ID for this story to check engagement
userIDQuery := "SELECT user_id FROM stories WHERE id = $1"
var userID uint
err = s.db.QueryRowContext(ctx, userIDQuery, storyID).Scan(&userID)
if err != nil {
return nil, contextutils.WrapErrorf(err, "failed to get user ID for story")
}
// Check if user has viewed the latest section
hasViewedLatest, err := s.HasUserViewedLatestSection(ctx, userID)
if err != nil {
return nil, contextutils.WrapErrorf(err, "failed to check user engagement")
}
if !hasViewedLatest {
return &models.StoryGenerationEligibilityResponse{
CanGenerate: false,
Reason: "user has not viewed the latest section",
}, nil
}
}
// Check generation count for today by generator type
today := time.Now().Truncate(24 * time.Hour)
var sectionCount int
sectionQuery := `
SELECT COUNT(*)
FROM story_sections
WHERE story_id = $1 AND generation_date = $2 AND generated_by = $3`
err = s.db.QueryRowContext(ctx, sectionQuery, storyID, today, generatorType).Scan(§ionCount)
if err != nil {
return nil, contextutils.WrapErrorf(err, "failed to check existing sections today by generator type")
}
span.SetAttributes(attribute.Int(fmt.Sprintf("section_count_%s", generatorType), sectionCount))
span.SetAttributes(attribute.Int("max_worker_generations_per_day", s.config.Story.MaxWorkerGenerationsPerDay))
span.SetAttributes(attribute.Int("max_user_generations_per_day", s.config.Story.MaxExtraGenerationsPerDay))
// Check limits based on generator type
switch generatorType {
case models.GeneratorTypeWorker:
// Worker can generate MaxWorkerGenerationsPerDay sections per day
if sectionCount >= s.config.Story.MaxWorkerGenerationsPerDay {
return &models.StoryGenerationEligibilityResponse{
CanGenerate: false,
Reason: fmt.Sprintf("worker has reached daily generation limit (%d)", s.config.Story.MaxWorkerGenerationsPerDay),
}, nil
}
case models.GeneratorTypeUser:
if sectionCount >= s.config.Story.MaxExtraGenerationsPerDay {
return &models.StoryGenerationEligibilityResponse{
CanGenerate: false,
Reason: fmt.Sprintf("user has reached daily generation limit (%d)", s.config.Story.MaxExtraGenerationsPerDay),
}, nil
}
default:
return &models.StoryGenerationEligibilityResponse{
CanGenerate: false,
Reason: "invalid generator type",
}, nil
}
// Allow generation if within limits
return &models.StoryGenerationEligibilityResponse{
CanGenerate: true,
}, nil
}
// UpdateLastGenerationTime stamps a story with the current time as its most
// recent section generation and, for user-triggered generations, bumps the
// story's extra_generations_today counter.
//
// The counter is only ever incremented for models.GeneratorTypeUser. When the
// previous generation falls into the same 24-hour truncation bucket as now,
// the increment is skipped once the counter has reached
// MaxExtraGenerationsPerDay + 1; in that case, and for every other generator
// type, only the timestamps are updated.
//
// The deferred FinishSpan receives a pointer to the named result, so the
// returned (wrapped) error is what the span sees.
func (s *StoryService) UpdateLastGenerationTime(ctx context.Context, storyID uint, generatorType models.GeneratorType) (err error) {
	ctx, span := observability.TraceFunction(ctx, "story_service", "update_last_generation_time",
		attribute.Int("story.id", int(storyID)),
		observability.AttributeGenerationType(generatorType),
	)
	defer observability.FinishSpan(span, &err)

	const (
		// Updates the timestamps and increments the per-day counter.
		touchAndCountSQL = "UPDATE stories SET extra_generations_today = extra_generations_today + 1, last_section_generated_at = $1, updated_at = NOW() WHERE id = $2"
		// Updates the timestamps only.
		touchOnlySQL = "UPDATE stories SET last_section_generated_at = $1, updated_at = NOW() WHERE id = $2"
	)

	// Read the current bookkeeping values for the story.
	var (
		previous  *time.Time
		extraUsed int
	)
	selectSQL := `
SELECT last_section_generated_at, extra_generations_today
FROM stories
WHERE id = $1`
	if err = s.db.QueryRowContext(ctx, selectSQL, storyID).Scan(&previous, &extraUsed); err != nil {
		return contextutils.WrapErrorf(err, "failed to get current generation info")
	}

	now := time.Now()
	// "Same day" means both instants truncate to the same 24-hour boundary.
	sameDay := previous != nil && previous.Truncate(24*time.Hour).Equal(now.Truncate(24*time.Hour))

	// Default to a timestamp-only update; user generations count unless the
	// same-day cap has already been reached.
	statement := touchOnlySQL
	if generatorType == models.GeneratorTypeUser {
		if !sameDay || extraUsed < s.config.Story.MaxExtraGenerationsPerDay+1 {
			statement = touchAndCountSQL
		}
	}

	if _, err = s.db.ExecContext(ctx, statement, now, storyID); err != nil {
		if sameDay {
			return contextutils.WrapErrorf(err, "failed to update generation time")
		}
		return contextutils.WrapErrorf(err, "failed to update generation time for first generation")
	}
	return nil
}
// RecordStorySectionView stores the fact that userID has viewed sectionID.
//
// The statement is an upsert keyed on (user_id, section_id): a first view
// inserts a row, a repeated view only refreshes viewed_at. A database failure
// is returned wrapped; the deferred FinishSpan receives a pointer to the named
// result.
func (s *StoryService) RecordStorySectionView(ctx context.Context, userID, sectionID uint) (err error) {
	ctx, span := observability.TraceFunction(ctx, "story_service", "record_section_view",
		observability.AttributeUserID(int(userID)),
		attribute.Int("section.id", int(sectionID)),
	)
	defer observability.FinishSpan(span, &err)

	upsertView := `
INSERT INTO story_section_views (user_id, section_id, viewed_at, created_at)
VALUES ($1, $2, NOW(), NOW())
ON CONFLICT (user_id, section_id)
DO UPDATE SET viewed_at = NOW()`

	if _, err = s.db.ExecContext(ctx, upsertView, userID, sectionID); err != nil {
		return contextutils.WrapErrorf(err, "failed to record story section view")
	}
	return nil
}
// HasUserViewedLatestSection reports whether the user has a recorded view of
// the last section of their current story.
//
// Results:
//   - (false, nil) when GetCurrentStory returns no story;
//   - (true, nil) when the story has no sections yet, so that the first
//     section can be generated;
//   - otherwise whether a story_section_views row exists for the user and the
//     final element of story.Sections.
//
// Errors from loading the story or querying the views are returned wrapped.
// The results are named so the deferred FinishSpan can observe the returned
// error, matching the other traced methods of StoryService.
func (s *StoryService) HasUserViewedLatestSection(ctx context.Context, userID uint) (viewed bool, err error) {
	ctx, span := observability.TraceFunction(ctx, "story_service", "has_user_viewed_latest_section",
		observability.AttributeUserID(int(userID)),
	)
	defer observability.FinishSpan(span, &err)

	// Get the user's current active story
	story, err := s.GetCurrentStory(ctx, userID)
	if err != nil {
		return false, contextutils.WrapErrorf(err, "failed to get current story")
	}
	if story == nil {
		// No current story - can't generate anything
		return false, nil
	}
	if len(story.Sections) == 0 {
		// Story exists but has no sections yet - allow first section generation
		return true, nil
	}

	// The last element is treated as the latest section.
	// NOTE(review): assumes Sections is ordered by ascending section number - confirm in GetCurrentStory.
	latestSection := story.Sections[len(story.Sections)-1]

	// Check if user has viewed this section
	query := `
SELECT EXISTS(
SELECT 1 FROM story_section_views
WHERE user_id = $1 AND section_id = $2
)`
	err = s.db.QueryRowContext(ctx, query, userID, latestSection.ID).Scan(&viewed)
	if err != nil {
		return false, contextutils.WrapErrorf(err, "failed to check if user viewed latest section")
	}
	return viewed, nil
}
// Helper methods
// getUserByID loads a single row from users by primary key.
//
// It returns (nil, nil) when no row matches, the populated user on success,
// and a wrapped error for any other scan/query failure.
func (s *StoryService) getUserByID(ctx context.Context, userID uint) (*models.User, error) {
	const selectUser = "SELECT id, username, email, preferred_language, current_level, ai_provider, ai_model, ai_api_key, created_at, updated_at FROM users WHERE id = $1"

	found := &models.User{}
	row := s.db.QueryRowContext(ctx, selectUser, userID)
	scanErr := row.Scan(
		&found.ID, &found.Username, &found.Email, &found.PreferredLanguage,
		&found.CurrentLevel, &found.AIProvider, &found.AIModel, &found.AIAPIKey,
		&found.CreatedAt, &found.UpdatedAt,
	)

	switch {
	case scanErr == nil:
		return found, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		// User not found is reported as a nil user, not as an error.
		return nil, nil
	default:
		return nil, contextutils.WrapErrorf(scanErr, "failed to get user")
	}
}
// getArchivedStoryCount returns how many stories owned by userID have the
// status models.StoryStatusArchived. A scan/query error is returned as-is
// (not wrapped) together with a zero count.
func (s *StoryService) getArchivedStoryCount(ctx context.Context, userID uint) (int, error) {
	var archived int
	row := s.db.QueryRowContext(ctx,
		"SELECT COUNT(*) FROM stories WHERE user_id = $1 AND status = $2",
		userID, models.StoryStatusArchived,
	)
	if scanErr := row.Scan(&archived); scanErr != nil {
		return 0, scanErr
	}
	return archived, nil
}
// getUserCurrentLevel reads users.current_level for userID.
//
// It returns a wrapped error when the row cannot be read (including when no
// row exists) and a separate error when the column is NULL.
func (s *StoryService) getUserCurrentLevel(ctx context.Context, userID uint) (string, error) {
	var current sql.NullString
	row := s.db.QueryRowContext(ctx, "SELECT current_level FROM users WHERE id = $1", userID)
	if scanErr := row.Scan(&current); scanErr != nil {
		return "", contextutils.WrapErrorf(scanErr, "failed to get user")
	}
	if current.Valid {
		return current.String, nil
	}
	return "", contextutils.ErrorWithContextf("user has no current level set")
}
// validateStoryOwnership checks that a stories row exists with the given id
// and user_id. It returns nil when at least one such row exists, a
// "story not found or access denied" error when none does, and a wrapped error
// when the count query itself fails.
func (s *StoryService) validateStoryOwnership(ctx context.Context, storyID, userID uint) error {
	var matches int
	row := s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM stories WHERE id = $1 AND user_id = $2", storyID, userID)

	switch scanErr := row.Scan(&matches); {
	case scanErr != nil:
		return contextutils.WrapErrorf(scanErr, "failed to validate story ownership")
	case matches == 0:
		return contextutils.ErrorWithContextf("story not found or access denied")
	default:
		return nil
	}
}
// GetSectionLengthTarget returns the target word count for a section at the
// given level and optional length preference. It delegates entirely to
// models.GetSectionLengthTarget and applies no service-level overrides.
func (s *StoryService) GetSectionLengthTarget(level string, lengthPref *models.SectionLength) int {
	words := models.GetSectionLengthTarget(level, lengthPref)
	return words
}
// GetSectionLengthTargetWithLanguage returns the target word count for a
// section, preferring per-language overrides from the service configuration.
//
// Lookup order:
//  1. Overrides[language][level][*lengthPref], when lengthPref is non-nil;
//  2. Overrides[language][level]["medium"];
//  3. models.GetSectionLengthTarget(level, lengthPref) when no override for
//     the language/level exists or neither key above is present.
func (s *StoryService) GetSectionLengthTargetWithLanguage(language, level string, lengthPref *models.SectionLength) int {
	byLevel, found := s.config.Story.SectionLengths.Overrides[language]
	if !found {
		return models.GetSectionLengthTarget(level, lengthPref)
	}

	targets, found := byLevel[level]
	if !found {
		return models.GetSectionLengthTarget(level, lengthPref)
	}

	// An explicit preference wins when the override table has an entry for it.
	if lengthPref != nil {
		if words, ok := targets[string(*lengthPref)]; ok {
			return words
		}
	}

	// Otherwise use the override's medium entry, if any.
	if words, ok := targets["medium"]; ok {
		return words
	}

	// Nothing usable in the overrides: use the default implementation.
	return models.GetSectionLengthTarget(level, lengthPref)
}
// SanitizeInput returns input as cleaned by models.SanitizeInput. The method
// adds nothing of its own; it exists so callers can go through the service.
func (s *StoryService) SanitizeInput(input string) string {
	cleaned := models.SanitizeInput(input)
	return cleaned
}
// Database helper methods using sql.DB
// createStory inserts story into the stories table and writes the generated
// primary key back into story.ID. The error from the insert/scan is returned
// unwrapped.
func (s *StoryService) createStory(ctx context.Context, story *models.Story) error {
	insertSQL := `
INSERT INTO stories (
user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
created_at, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING id`

	// Argument order must match the column list above.
	row := s.db.QueryRowContext(ctx, insertSQL,
		story.UserID, story.Title, story.Language, story.Subject, story.AuthorStyle,
		story.TimePeriod, story.Genre, story.Tone, story.CharacterNames,
		story.CustomInstructions, story.SectionLengthOverride, story.Status,
		story.CreatedAt, story.UpdatedAt,
	)
	return row.Scan(&story.ID)
}
// Admin-only methods (no ownership checks)
// GetStoriesPaginated returns one page of stories, newest first, plus the
// total number of stories matching the filters. It performs no ownership
// checks.
//
// Parameters:
//   - page: 1-based page number; values <= 0 are treated as 1.
//   - pageSize: rows per page; values <= 0 or > 100 are replaced by 20.
//   - search: when non-empty, a case-insensitive substring match on title
//     (LIKE wildcards contained in search are not escaped).
//   - language, status: when non-empty, exact-match filters.
//   - userID: when non-nil, restricts results to that owner.
//
// On any failure, including an error surfaced while iterating the result set,
// it returns (nil, 0, wrapped error) so callers never receive a partial page.
func (s *StoryService) GetStoriesPaginated(ctx context.Context, page, pageSize int, search, language, status string, userID *uint) ([]models.Story, int, error) {
	if page <= 0 {
		page = 1
	}
	if pageSize <= 0 || pageSize > 100 {
		pageSize = 20
	}

	// Build WHERE clauses dynamically; every value is bound as a numbered
	// placeholder, never interpolated into the SQL text.
	where := []string{"1=1"}
	args := []interface{}{}
	argIdx := 1
	if search != "" {
		where = append(where, fmt.Sprintf("(LOWER(title) LIKE $%d)", argIdx))
		args = append(args, "%"+strings.ToLower(search)+"%")
		argIdx++
	}
	if language != "" {
		where = append(where, fmt.Sprintf("language = $%d", argIdx))
		args = append(args, language)
		argIdx++
	}
	if status != "" {
		where = append(where, fmt.Sprintf("status = $%d", argIdx))
		args = append(args, status)
		argIdx++
	}
	if userID != nil {
		where = append(where, fmt.Sprintf("user_id = $%d", argIdx))
		args = append(args, *userID)
		argIdx++
	}
	whereClause := strings.Join(where, " AND ")

	// Count total rows matching the filters (ignores paging).
	countQuery := "SELECT COUNT(*) FROM stories WHERE " + whereClause
	var total int
	if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, contextutils.WrapErrorf(err, "failed to count stories")
	}

	// Fetch the requested page; LIMIT/OFFSET take the next two placeholders.
	offset := (page - 1) * pageSize
	listQuery := `
SELECT id, user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
last_section_generated_at, created_at, updated_at
FROM stories
WHERE ` + whereClause + `
ORDER BY created_at DESC
LIMIT $` + fmt.Sprint(argIdx) + ` OFFSET $` + fmt.Sprint(argIdx+1)
	args = append(args, pageSize, offset)

	rows, err := s.db.QueryContext(ctx, listQuery, args...)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(err, "failed to query stories")
	}
	defer func() { _ = rows.Close() }()

	// Start from an empty (non-nil) slice, as the original did.
	stories := []models.Story{}
	for rows.Next() {
		var story models.Story
		if err := rows.Scan(
			&story.ID, &story.UserID, &story.Title, &story.Language, &story.Subject,
			&story.AuthorStyle, &story.TimePeriod, &story.Genre, &story.Tone,
			&story.CharacterNames, &story.CustomInstructions, &story.SectionLengthOverride,
			&story.Status, &story.LastSectionGeneratedAt,
			&story.CreatedAt, &story.UpdatedAt,
		); err != nil {
			return nil, 0, contextutils.WrapErrorf(err, "failed to scan story")
		}
		stories = append(stories, story)
	}
	// rows.Next returning false can also mean the iteration failed; surface
	// that like every other failure instead of alongside a partial page.
	if err := rows.Err(); err != nil {
		return nil, 0, contextutils.WrapErrorf(err, "failed to iterate stories")
	}
	return stories, total, nil
}
// GetStoryAdmin loads a story by id together with its sections, without any
// ownership check.
//
// It returns a "story not found" error when no row matches, and wrapped
// errors when the story row or its sections cannot be loaded.
func (s *StoryService) GetStoryAdmin(ctx context.Context, storyID uint) (*models.StoryWithSections, error) {
	selectStory := `
SELECT id, user_id, title, language, subject, author_style, time_period, genre, tone,
character_names, custom_instructions, section_length_override, status,
last_section_generated_at, created_at, updated_at
FROM stories
WHERE id = $1`

	story := models.Story{}
	row := s.db.QueryRowContext(ctx, selectStory, storyID)
	scanErr := row.Scan(
		&story.ID, &story.UserID, &story.Title, &story.Language, &story.Subject,
		&story.AuthorStyle, &story.TimePeriod, &story.Genre, &story.Tone,
		&story.CharacterNames, &story.CustomInstructions, &story.SectionLengthOverride,
		&story.Status, &story.LastSectionGeneratedAt,
		&story.CreatedAt, &story.UpdatedAt,
	)
	if errors.Is(scanErr, sql.ErrNoRows) {
		return nil, contextutils.ErrorWithContextf("story not found")
	}
	if scanErr != nil {
		return nil, contextutils.WrapErrorf(scanErr, "failed to get story")
	}

	sections, sectionsErr := s.GetStorySections(ctx, story.ID)
	if sectionsErr != nil {
		return nil, contextutils.WrapErrorf(sectionsErr, "failed to load story sections")
	}

	result := &models.StoryWithSections{Story: story, Sections: sections}
	return result, nil
}
// GetSectionAdmin returns section with questions for admin (no ownership checks)
func (s *StoryService) GetSectionAdmin(ctx context.Context, sectionID uint) (*models.StorySectionWithQuestions, error) {
query := `
SELECT id, story_id, section_number, content, language_level, word_count,
generated_by, generated_at, generation_date
FROM story_sections
WHERE id = $1`
var section models.StorySection
if err := s.db.QueryRowContext(ctx, query, sectionID).Scan(
§ion.ID, §ion.StoryID, §ion.SectionNumber, §ion.Content,
§ion.LanguageLevel, §ion.WordCount, §ion.GeneratedBy, §ion.GeneratedAt, §ion.GenerationDate,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, contextutils.ErrorWithContextf("section not found")
}
return nil, contextutils.WrapErrorf(err, "failed to get section")
}
questions, err := s.GetSectionQuestions(ctx, section.ID)
if err != nil {
return nil, contextutils.WrapErrorf(err, "failed to load section questions")
}
return &models.StorySectionWithQuestions{StorySection: section, Questions: questions}, nil
}
// createSection inserts a new section into the database
func (s *StoryService) createSection(ctx context.Context, section *models.StorySection) error {
query := `
INSERT INTO story_sections (
story_id, section_number, content, language_level, word_count, generated_by,
generated_at, generation_date
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id`
err := s.db.QueryRowContext(ctx, query,
section.StoryID, section.SectionNumber, section.Content, section.LanguageLevel,
section.WordCount, section.GeneratedBy, section.GeneratedAt, section.GenerationDate,
).Scan(§ion.ID)
return err
}
// createSectionInTx creates a section within an existing database transaction
func (s *StoryService) createSectionInTx(ctx context.Context, tx *sql.Tx, section *models.StorySection) error {
query := `
INSERT INTO story_sections (
story_id, section_number, content, language_level, word_count, generated_by,
generated_at, generation_date
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id`
err := tx.QueryRowContext(ctx, query,
section.StoryID, section.SectionNumber, section.Content, section.LanguageLevel,
section.WordCount, section.GeneratedBy, section.GeneratedAt, section.GenerationDate,
).Scan(§ion.ID)
return err
}
// GenerateStorySection generates the next section of a story with the AI
// service, persists it with its comprehension questions, and returns the
// stored section together with the questions read back from the database.
//
// Flow:
//  1. Load the story via GetStory(storyID, userID) and check eligibility with
//     canGenerateSection; an ineligible story yields ErrGenerationLimitReached
//     wrapped with the eligibility reason.
//  2. Load the user and require a valid CurrentLevel.
//  3. Ask aiService for the section text. If ctx has hit its deadline the
//     error is reported as ErrTimeout.
//  4. In one transaction: compute the next section number, insert the section,
//     then generate and insert the questions. Question generation or storage
//     failures are logged and do not abort the call.
//  5. After commit, update the story's last generation time and reload the
//     questions; failures in either step are logged only.
//
// userAIConfig must be non-nil: its fields are read when the span is created.
//
// NOTE(review): FinishSpan is passed nil here, unlike the methods that pass a
// pointer to a named error result, so returned errors are presumably not
// attached to this span - confirm whether that is intentional.
func (s *StoryService) GenerateStorySection(ctx context.Context, storyID, userID uint, aiService AIServiceInterface, userAIConfig *models.UserAIConfig, generatorType models.GeneratorType) (*models.StorySectionWithQuestions, error) {
	ctx, span := observability.TraceFunction(ctx, "story_service", "generate_section",
		attribute.Int("story.id", int(storyID)),
		observability.AttributeUserID(int(userID)),
		observability.AttributeGenerationType(generatorType),
		attribute.String("model", userAIConfig.Model),
		attribute.String("provider", userAIConfig.Provider),
		attribute.String("username", userAIConfig.Username),
	)
	defer observability.FinishSpan(span, nil)

	// Get the story to verify ownership and get details
	story, err := s.GetStory(ctx, storyID, userID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get story for generation")
	}
	span.SetAttributes(attribute.String("story.title", story.Title))
	span.SetAttributes(attribute.String("story.language", story.Language))
	span.SetAttributes(attribute.String("story.section_length_override", story.GetSectionLengthOverride()))
	span.SetAttributes(attribute.String("story.subject", stringPtrToString(story.Subject)))
	span.SetAttributes(attribute.String("story.author_style", stringPtrToString(story.AuthorStyle)))
	span.SetAttributes(attribute.String("story.time_period", stringPtrToString(story.TimePeriod)))
	span.SetAttributes(attribute.String("story.genre", stringPtrToString(story.Genre)))
	span.SetAttributes(attribute.String("story.tone", stringPtrToString(story.Tone)))
	span.SetAttributes(attribute.String("story.character_names", stringPtrToString(story.CharacterNames)))
	span.SetAttributes(attribute.String("story.custom_instructions", stringPtrToString(story.CustomInstructions)))

	// Check if generation is allowed today by this generator type
	eligibility, err := s.canGenerateSection(ctx, storyID, generatorType)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to check generation eligibility")
	}
	if !eligibility.CanGenerate {
		return nil, contextutils.WrapError(contextutils.ErrGenerationLimitReached, eligibility.Reason)
	}

	// Get user for AI configuration and language preferences
	user, err := s.getUserByID(ctx, userID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get user")
	}
	if user == nil {
		return nil, contextutils.ErrorWithContextf("user not found")
	}

	// Get all previous sections for context
	previousSections, err := s.GetAllSectionsText(ctx, storyID)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get previous sections")
	}

	// Get the user's current language level (handle sql.NullString)
	if !user.CurrentLevel.Valid {
		return nil, contextutils.ErrorWithContextf("user level not found")
	}
	span.SetAttributes(attribute.String("story.level", user.CurrentLevel.String))

	// Determine target length for this user's level
	targetWords := s.GetSectionLengthTarget(user.CurrentLevel.String, story.SectionLengthOverride)

	// Build the generation request
	genReq := &models.StoryGenerationRequest{
		UserID:             userID,
		StoryID:            storyID,
		Language:           story.Language,
		Level:              user.CurrentLevel.String,
		Title:              story.Title,
		Subject:            story.Subject,
		AuthorStyle:        story.AuthorStyle,
		TimePeriod:         story.TimePeriod,
		Genre:              story.Genre,
		Tone:               story.Tone,
		CharacterNames:     story.CharacterNames,
		CustomInstructions: story.CustomInstructions,
		SectionLength:      models.SectionLengthMedium, // Use medium as default
		PreviousSections:   previousSections,
		IsFirstSection:     len(story.Sections) == 0,
		TargetWords:        targetWords,
		TargetSentences:    targetWords / 15, // Rough estimate
	}

	// Generate the story section using AI
	sectionContent, err := aiService.GenerateStorySection(ctx, userAIConfig, genReq)
	if err != nil {
		// Distinguish a deadline on our own context from other AI failures
		if ctx.Err() == context.DeadlineExceeded {
			s.logger.Error(ctx, "Story section generation timed out", err, map[string]interface{}{
				"story_id": storyID,
				"user_id":  userID,
			})
			return nil, contextutils.WrapErrorf(contextutils.ErrTimeout, "story generation timed out: %w", err)
		}
		return nil, contextutils.WrapErrorf(err, "failed to generate story section")
	}

	// Count words in the generated content (whitespace-separated fields)
	wordCount := len(strings.Fields(sectionContent))

	// Start a database transaction to ensure atomicity of section and questions creation
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to begin transaction")
	}
	span.AddEvent("transaction_began")

	// Roll back on every exit path that did not reach a successful commit.
	var committed bool
	defer func() {
		if !committed {
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				s.logger.Warn(ctx, "Failed to rollback transaction",
					map[string]interface{}{
						"story_id": storyID,
						"user_id":  userID,
						"error":    rollbackErr.Error(),
					})
			}
			span.AddEvent("transaction_rolled_back")
		}
	}()

	// Create the section within the transaction
	section := &models.StorySection{
		StoryID:        storyID,
		SectionNumber:  0, // Placeholder; assigned below from MAX(section_number) + 1
		Content:        sectionContent,
		LanguageLevel:  user.CurrentLevel.String,
		WordCount:      wordCount,
		GeneratedBy:    generatorType,
		GeneratedAt:    time.Now(),
		GenerationDate: time.Now().Truncate(24 * time.Hour),
	}

	// Get the next section number within the transaction
	var maxSectionNumber int
	query := "SELECT COALESCE(MAX(section_number), 0) FROM story_sections WHERE story_id = $1"
	err = tx.QueryRowContext(ctx, query, storyID).Scan(&maxSectionNumber)
	if err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to get max section number")
	}
	section.SectionNumber = maxSectionNumber + 1
	span.SetAttributes(attribute.Int("section.number", section.SectionNumber))

	// Create the section in the database within the transaction
	if err := s.createSectionInTx(ctx, tx, section); err != nil {
		return nil, contextutils.WrapErrorf(err, "failed to create story section")
	}
	span.AddEvent("section_created")

	// Generate comprehension questions for the section
	questionsReq := &models.StoryQuestionsRequest{
		UserID:        userID,
		SectionID:     section.ID,
		Language:      story.Language,
		Level:         user.CurrentLevel.String,
		SectionText:   sectionContent,
		QuestionCount: s.config.Story.QuestionsPerSection,
	}
	var questions []*models.StorySectionQuestionData
	questions, err = aiService.GenerateStoryQuestions(ctx, userAIConfig, questionsReq)
	if err != nil {
		// Distinguish a deadline on our own context from other AI failures
		if ctx.Err() == context.DeadlineExceeded {
			s.logger.Warn(ctx, "Story questions generation timed out, continuing without questions",
				map[string]interface{}{
					"section_id": section.ID,
					"story_id":   storyID,
					"user_id":    userID,
					"error":      err.Error(),
				})
		} else {
			s.logger.Warn(ctx, "Failed to generate questions for story section",
				map[string]interface{}{
					"section_id": section.ID,
					"story_id":   storyID,
					"user_id":    userID,
					"error":      err.Error(),
				})
			span.AddEvent("failed_to_generate_questions")
		}
		// Continue anyway - questions are nice to have but not critical
	} else {
		// Convert to database model slice (dereference pointers)
		dbQuestions := make([]models.StorySectionQuestionData, len(questions))
		for i, q := range questions {
			dbQuestions[i] = *q
		}
		// Save the questions to the database within the same transaction
		if err := s.createSectionQuestionsInTx(ctx, tx, section.ID, dbQuestions); err != nil {
			s.logger.Warn(ctx, "Failed to save story questions",
				map[string]interface{}{
					"section_id": section.ID,
					"story_id":   storyID,
					"user_id":    userID,
					"error":      err.Error(),
				})
			span.AddEvent("failed_to_save_questions")
			// NOTE(review): on PostgreSQL a failed statement aborts the
			// enclosing transaction, so the commit below may fail after this
			// warning; a SAVEPOINT around the insert would be needed to keep
			// the section - confirm against the driver in use.
		} else {
			// Only report success when the questions were actually stored.
			span.AddEvent("questions_saved")
		}
	}

	// Commit the transaction
	if err := tx.Commit(); err != nil {
		span.AddEvent("failed_to_commit_transaction")
		return nil, contextutils.WrapErrorf(err, "failed to commit transaction")
	}
	committed = true
	span.AddEvent("transaction_committed")

	// Update the story's last generation time (best effort, outside the transaction)
	if err := s.UpdateLastGenerationTime(ctx, storyID, generatorType); err != nil {
		s.logger.Warn(ctx, "Failed to update story generation time",
			map[string]interface{}{
				"story_id": storyID,
				"user_id":  userID,
				"error":    err.Error(),
			})
	}
	s.logger.Info(ctx, "Story section generated successfully",
		map[string]interface{}{
			"story_id":       storyID,
			"section_id":     section.ID,
			"section_number": section.SectionNumber,
			"user_id":        userID,
			"word_count":     wordCount,
			"question_count": len(questions),
		})

	// Load questions for the section
	sectionQuestions, err := s.GetSectionQuestions(ctx, section.ID)
	if err != nil {
		s.logger.Warn(ctx, "Failed to load section questions after generation",
			map[string]interface{}{
				"section_id": section.ID,
				"story_id":   storyID,
				"user_id":    userID,
				"error":      err.Error(),
			})
		// Return section without questions rather than failing
		sectionQuestions = []models.StorySectionQuestion{}
	}
	return &models.StorySectionWithQuestions{
		StorySection: *section,
		Questions:    sectionQuestions,
	}, nil
}
// Package services provides business logic services for the quiz application.
package services
import (
"context"
"database/sql"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// TestEmailService implements the Mailer interface for testing purposes.
// It doesn't actually send emails but logs the operations and records them in the database
// when a database handle has been supplied.
type TestEmailService struct {
	// cfg is the application configuration; SendDailyReminder reads
	// cfg.Server.AppBaseURL from it.
	cfg *config.Config
	// logger receives the "TEST MODE" log entries written instead of sending mail.
	logger *observability.Logger
	// db is optional. It stays nil when the service is built with
	// NewTestEmailService, in which case notifications are not recorded.
	db *sql.DB
}
// NewTestEmailService builds a TestEmailService that only logs. No database
// handle is set, so notification recording is skipped.
func NewTestEmailService(cfg *config.Config, logger *observability.Logger) *TestEmailService {
	svc := new(TestEmailService)
	svc.cfg = cfg
	svc.logger = logger
	return svc
}
// NewTestEmailServiceWithDB builds a TestEmailService that logs and, through
// db, can also record notifications in the database.
func NewTestEmailServiceWithDB(cfg *config.Config, logger *observability.Logger, db *sql.DB) *TestEmailService {
	svc := new(TestEmailService)
	svc.cfg = cfg
	svc.logger = logger
	svc.db = db
	return svc
}
// SendDailyReminder is the test-mode stand-in for sending a daily reminder:
// it logs what would have been sent and always returns nil.
//
// Users without a valid, non-empty email address are skipped with a warning.
// Nothing is written to the database here.
func (e *TestEmailService) SendDailyReminder(ctx context.Context, user *models.User) error {
	tracer := otel.Tracer("test-email-service")
	ctx, span := tracer.Start(ctx, "SendDailyReminder",
		trace.WithAttributes(
			attribute.Int("user.id", user.ID),
			attribute.String("user.email", user.Email.String),
		),
	)
	defer span.End()

	hasEmail := user.Email.Valid && user.Email.String != ""
	if !hasEmail {
		e.logger.Warn(ctx, "User has no email address, skipping daily reminder", map[string]interface{}{
			"user_id": user.ID,
		})
		return nil
	}

	// Build the same template data as the real service. The value is discarded
	// in test mode and only kept so both services stay structurally alike.
	_ = map[string]interface{}{
		"Username":       user.Username,
		"QuizAppURL":     e.cfg.Server.AppBaseURL,
		"CurrentDate":    time.Now().Format("January 2, 2006"),
		"DailyGoal":      10,
		"StreakDays":     5,
		"TotalQuestions": 150,
		"Level":          "B1",
		"Language":       "Italian",
	}

	// Log instead of sending. The subject mirrors the real service to avoid
	// confusion. Do NOT record a database entry here - recording is handled by
	// the caller so there is a single source of truth for sent notifications.
	// NOTE(review): the last character of the subject literal looks
	// mis-encoded (probably an emoji originally) - confirm against the real
	// service's subject.
	e.logger.Info(ctx, "TEST MODE: Would send daily reminder email", map[string]interface{}{
		"user_id":   user.ID,
		"email":     user.Email.String,
		"template":  "daily_reminder",
		"subject":   "Time for your daily quiz! ð",
		"test_mode": true,
	})
	return nil
}
// SendEmail is the test-mode stand-in for sending an arbitrary templated
// email: it logs the request (including the keys of data, not its values) and
// always returns nil.
//
// When the service has a database handle, the email is also recorded through
// RecordSentNotification with user ID 0 and type "test_email"; a failure to
// record is logged and otherwise ignored.
func (e *TestEmailService) SendEmail(ctx context.Context, to, subject, templateName string, data map[string]interface{}) error {
	tracer := otel.Tracer("test-email-service")
	ctx, span := tracer.Start(ctx, "SendEmail",
		trace.WithAttributes(
			attribute.String("email.to", to),
			attribute.String("email.subject", subject),
			attribute.String("email.template", templateName),
		),
	)
	defer span.End()

	// Log the email operation instead of sending it.
	e.logger.Info(ctx, "TEST MODE: Would send email", map[string]interface{}{
		"to":        to,
		"subject":   subject,
		"template":  templateName,
		"test_mode": true,
		"data_keys": getMapKeys(data),
	})

	// Without a database there is nothing left to do.
	if e.db == nil {
		return nil
	}

	// Test emails are not tied to a user, so user ID 0 is recorded.
	if recordErr := e.RecordSentNotification(ctx, 0, "test_email", subject, templateName, "sent", ""); recordErr != nil {
		e.logger.Error(ctx, "Failed to record test notification", recordErr, map[string]interface{}{
			"to":       to,
			"template": templateName,
		})
	}
	return nil
}
// RecordSentNotification inserts one row into sent_notifications, using the
// current time as sent_at.
//
// When the service has no database handle it logs a warning and returns nil.
// An insert failure is recorded on the span, logged, and returned wrapped.
func (e *TestEmailService) RecordSentNotification(ctx context.Context, userID int, notificationType, subject, templateName, status, errorMessage string) error {
	tracer := otel.Tracer("test-email-service")
	ctx, span := tracer.Start(ctx, "RecordSentNotification",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("notification.type", notificationType),
			attribute.String("notification.status", status),
		),
	)
	defer span.End()

	if e.db == nil {
		e.logger.Warn(ctx, "No database connection available for recording notification", map[string]interface{}{
			"user_id":           userID,
			"notification_type": notificationType,
		})
		return nil
	}

	insertSQL := `
INSERT INTO sent_notifications (user_id, notification_type, subject, template_name, sent_at, status, error_message)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	if _, execErr := e.db.ExecContext(ctx, insertSQL, userID, notificationType, subject, templateName, time.Now(), status, errorMessage); execErr != nil {
		span.RecordError(execErr)
		e.logger.Error(ctx, "Failed to record sent notification", execErr, map[string]interface{}{
			"user_id":           userID,
			"notification_type": notificationType,
			"status":            status,
		})
		return contextutils.WrapError(execErr, "failed to record sent notification")
	}

	e.logger.Info(ctx, "Recorded sent notification", map[string]interface{}{
		"user_id":           userID,
		"notification_type": notificationType,
		"status":            status,
	})
	return nil
}
// IsEnabled reports whether email functionality is enabled. The test service
// is unconditionally enabled, so this always returns true.
func (e *TestEmailService) IsEnabled() bool {
	const alwaysEnabled = true
	return alwaysEnabled
}
// getMapKeys returns the keys of data as a newly allocated, non-nil slice.
// The keys appear in map iteration order, which Go leaves unspecified.
func getMapKeys(data map[string]interface{}) []string {
	names := make([]string, len(data))
	next := 0
	for name := range data {
		names[next] = name
		next++
	}
	return names
}
//go:build integration
package services
import (
"context"
"database/sql"
"os"
"testing"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/observability"
"github.com/stretchr/testify/require"
)
// SharedTestDBSetup connects to the database named by the TEST_DATABASE_URL
// environment variable, runs CleanupTestDatabase on it, and returns the
// handle for use by an integration test.
//
// The test fails immediately when TEST_DATABASE_URL is unset or empty, or when
// the database cannot be initialised.
func SharedTestDBSetup(t *testing.T) *sql.DB {
	// Logging is disabled for the manager used in tests.
	quietLogger := observability.NewLogger(&config.OpenTelemetryConfig{EnableLogging: false})
	manager := database.NewManager(quietLogger)

	// TEST_DATABASE_URL is mandatory for integration tests.
	dsn := os.Getenv("TEST_DATABASE_URL")
	if len(dsn) == 0 {
		t.Fatal("TEST_DATABASE_URL environment variable must be set for integration tests")
	}

	conn, err := manager.InitDB(dsn)
	require.NoError(t, err)

	// Clean the database before handing it to the test.
	CleanupTestDatabase(conn, t)
	return conn
}
// cleanupDatabase performs the core database cleanup operations.
// This is the shared implementation used by both CleanupTestDatabase and SharedTestSuite.Cleanup.
//
// Inside a single transaction it truncates the application tables, restarts
// the listed ID sequences at 1 and re-inserts the default 'global_pause'
// worker setting. Every statement is best-effort: a failure is logged (when
// logger is non-nil) and the remaining statements still run. Nothing is
// returned to the caller; with a nil logger failures are silent.
//
// The SQL is PostgreSQL dialect. There, a failed statement aborts the whole
// enclosing transaction, so each statement runs inside its own savepoint (see
// execCleanupStatement); otherwise one missing table or sequence would
// discard the entire cleanup.
func cleanupDatabase(db *sql.DB, logger *observability.Logger) {
	ctx := context.Background()

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		if logger != nil {
			logger.Error(ctx, "Failed to begin cleanup transaction", err)
		}
		return
	}
	// After a successful Commit this Rollback only returns sql.ErrTxDone; on
	// any other exit (including a panic) it ends the transaction.
	defer func() { _ = tx.Rollback() }()

	// Fast cleanup with batched operations
	cleanupQueries := []string{
		"TRUNCATE TABLE user_responses CASCADE",
		"TRUNCATE TABLE performance_metrics CASCADE",
		"TRUNCATE TABLE user_question_metadata CASCADE",
		"TRUNCATE TABLE question_priority_scores CASCADE",
		"TRUNCATE TABLE user_learning_preferences CASCADE",
		"TRUNCATE TABLE user_questions CASCADE",
		"TRUNCATE TABLE questions CASCADE",
		"TRUNCATE TABLE worker_status CASCADE",
		"TRUNCATE TABLE worker_settings CASCADE",
		"TRUNCATE TABLE user_api_keys CASCADE",
		"TRUNCATE TABLE user_roles CASCADE",
		"TRUNCATE TABLE question_reports CASCADE",
		"TRUNCATE TABLE notification_errors CASCADE",
		"TRUNCATE TABLE upcoming_notifications CASCADE",
		"TRUNCATE TABLE sent_notifications CASCADE",
		"TRUNCATE TABLE daily_question_assignments CASCADE",
		"TRUNCATE TABLE story_sections CASCADE",
		"TRUNCATE TABLE story_section_questions CASCADE",
		"TRUNCATE TABLE stories CASCADE",
		"TRUNCATE TABLE snippets CASCADE",
		"TRUNCATE TABLE usage_stats CASCADE",
		"TRUNCATE TABLE users CASCADE",
	}
	for _, query := range cleanupQueries {
		if execErr := execCleanupStatement(ctx, tx, query); execErr != nil && logger != nil {
			logger.Warn(ctx, "Could not execute cleanup query", map[string]interface{}{
				"query": query,
				"error": execErr.Error(),
			})
		}
	}

	// Reset sequences
	sequenceQueries := []string{
		"ALTER SEQUENCE users_id_seq RESTART WITH 1",
		"ALTER SEQUENCE questions_id_seq RESTART WITH 1",
		"ALTER SEQUENCE user_responses_id_seq RESTART WITH 1",
		"ALTER SEQUENCE performance_metrics_id_seq RESTART WITH 1",
		"ALTER SEQUENCE snippets_id_seq RESTART WITH 1",
	}
	for _, query := range sequenceQueries {
		if execErr := execCleanupStatement(ctx, tx, query); execErr != nil && logger != nil {
			logger.Warn(ctx, "Could not reset sequence", map[string]interface{}{
				"query": query,
				"error": execErr.Error(),
			})
		}
	}

	// Re-insert default worker settings
	insertErr := execCleanupStatement(ctx, tx, `
INSERT INTO worker_settings (setting_key, setting_value, created_at, updated_at)
VALUES ('global_pause', 'false', NOW(), NOW())
ON CONFLICT (setting_key) DO NOTHING;
`)
	if insertErr != nil && logger != nil {
		logger.Error(ctx, "Failed to insert worker settings", insertErr)
	}

	if err = tx.Commit(); err != nil && logger != nil {
		logger.Error(ctx, "Failed to commit cleanup transaction", err)
	}
}

// execCleanupStatement runs one best-effort statement inside its own savepoint
// so that its failure does not abort the surrounding transaction. It returns
// the statement's error (or the savepoint bookkeeping error), if any.
func execCleanupStatement(ctx context.Context, tx *sql.Tx, query string) error {
	if _, err := tx.ExecContext(ctx, "SAVEPOINT cleanup_step"); err != nil {
		return err
	}
	if _, err := tx.ExecContext(ctx, query); err != nil {
		// Undo only the failed statement. If this rollback itself fails the
		// transaction stays aborted and later statements will report it, so
		// the original error is the more useful one to return.
		_, _ = tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT cleanup_step")
		return err
	}
	_, err := tx.ExecContext(ctx, "RELEASE SAVEPOINT cleanup_step")
	return err
}
// CleanupTestDatabase cleans up the database for integration tests.
// This function can be used by any integration test that needs to clean up the database.
// It delegates to cleanupDatabase, which batches all statements in one transaction.
//
// cleanupDatabase is called with a nil logger, so cleanup failures are not
// logged, and the t parameter is currently unused: a failed cleanup does not
// fail the test.
func CleanupTestDatabase(db *sql.DB, t *testing.T) {
cleanupDatabase(db, nil)
}
package services
import (
"context"
"crypto/sha256"
"database/sql"
"fmt"
"sync"
"time"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/attribute"
)
// TranslationCacheRepository defines the interface for translation cache operations.
// Entries are addressed by the triple (text hash, source language, target language).
type TranslationCacheRepository interface {
// GetCachedTranslation retrieves a cached translation if it exists and is not expired.
// The implementations in this file return (nil, nil) on a miss.
GetCachedTranslation(ctx context.Context, textHash, sourceLang, targetLang string) (*models.TranslationCache, error)
// SaveTranslation stores a translation in the cache with a 30-day expiration,
// replacing any existing entry for the same key.
SaveTranslation(ctx context.Context, textHash, originalText, sourceLang, targetLang, translatedText string) error
// CleanupExpiredTranslations removes expired translation cache entries and
// returns how many were removed.
CleanupExpiredTranslations(ctx context.Context) (int64, error)
}
// TranslationCacheRepositoryImpl implements TranslationCacheRepository on top
// of the translation_cache SQL table.
type TranslationCacheRepositoryImpl struct {
// db is the connection pool every query runs against.
db *sql.DB
// logger receives the informational message emitted after a cleanup run.
logger *observability.Logger
}
// NewTranslationCacheRepository builds the SQL-backed translation cache
// repository from the given database handle and logger and returns it behind
// the TranslationCacheRepository interface.
func NewTranslationCacheRepository(db *sql.DB, logger *observability.Logger) TranslationCacheRepository {
	repo := new(TranslationCacheRepositoryImpl)
	repo.db = db
	repo.logger = logger
	return repo
}
// HashText returns the SHA-256 digest of text as a 64-character lowercase
// hexadecimal string. It is used as the cache key component for a text.
func HashText(text string) string {
	hasher := sha256.New()
	// hash.Hash.Write is documented to never return an error.
	hasher.Write([]byte(text))
	return fmt.Sprintf("%x", hasher.Sum(nil))
}
// GetCachedTranslation looks up the unexpired translation_cache row matching
// textHash, sourceLang and targetLang.
//
// It returns (nil, nil) when no such row exists, the populated entry on a hit,
// and a wrapped error for any other query failure. The outcome is also
// recorded on the tracing span via the "cache.found" attribute.
func (r *TranslationCacheRepositoryImpl) GetCachedTranslation(ctx context.Context, textHash, sourceLang, targetLang string) (result *models.TranslationCache, err error) {
	ctx, span := observability.TraceDatabaseFunction(ctx, "get_cached_translation",
		attribute.String("cache.text_hash", textHash),
		attribute.String("cache.source_language", sourceLang),
		attribute.String("cache.target_language", targetLang),
	)
	// FinishSpan observes the named err result at return time.
	defer observability.FinishSpan(span, &err)

	const selectCached = `
SELECT id, text_hash, original_text, source_language, target_language,
translated_text, created_at, expires_at
FROM translation_cache
WHERE text_hash = $1
AND source_language = $2
AND target_language = $3
AND expires_at > NOW()
`
	var entry models.TranslationCache
	row := r.db.QueryRowContext(ctx, selectCached, textHash, sourceLang, targetLang)
	err = row.Scan(
		&entry.ID,
		&entry.TextHash,
		&entry.OriginalText,
		&entry.SourceLanguage,
		&entry.TargetLanguage,
		&entry.TranslatedText,
		&entry.CreatedAt,
		&entry.ExpiresAt,
	)

	switch {
	case err == sql.ErrNoRows:
		// Not found or expired: a miss is not an error.
		span.SetAttributes(attribute.Bool("cache.found", false))
		return nil, nil
	case err != nil:
		err = contextutils.WrapError(err, "failed to query translation cache")
		return nil, err
	}

	span.SetAttributes(attribute.Bool("cache.found", true))
	return &entry, nil
}
// SaveTranslation upserts a translation into translation_cache with an expiry
// 30 days from now. When a row for (textHash, sourceLang, targetLang) already
// exists, its translated text and expiry are replaced and created_at is reset.
// A database failure is returned wrapped.
func (r *TranslationCacheRepositoryImpl) SaveTranslation(ctx context.Context, textHash, originalText, sourceLang, targetLang, translatedText string) (err error) {
	ctx, span := observability.TraceDatabaseFunction(ctx, "save_translation_cache",
		attribute.String("cache.text_hash", textHash),
		attribute.String("cache.source_language", sourceLang),
		attribute.String("cache.target_language", targetLang),
		attribute.Int("cache.original_text_length", len(originalText)),
		attribute.Int("cache.translated_text_length", len(translatedText)),
	)
	defer observability.FinishSpan(span, &err)

	// Cache entries live for 30 days.
	const cacheTTL = 30 * 24 * time.Hour
	const upsert = `
INSERT INTO translation_cache (text_hash, original_text, source_language, target_language, translated_text, expires_at)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (text_hash, source_language, target_language)
DO UPDATE SET
translated_text = EXCLUDED.translated_text,
expires_at = EXCLUDED.expires_at,
created_at = CURRENT_TIMESTAMP
`
	expiry := time.Now().Add(cacheTTL)

	if _, err = r.db.ExecContext(ctx, upsert, textHash, originalText, sourceLang, targetLang, translatedText, expiry); err != nil {
		return contextutils.WrapError(err, "failed to save translation to cache")
	}

	span.SetAttributes(attribute.String("cache.expires_at", expiry.Format(time.RFC3339)))
	return nil
}
// CleanupExpiredTranslations deletes every translation_cache row whose
// expires_at lies in the past and returns the number of rows removed.
// The count is also attached to the span and written to the logger.
// Failures of the DELETE or of reading the affected-row count are returned wrapped.
func (r *TranslationCacheRepositoryImpl) CleanupExpiredTranslations(ctx context.Context) (count int64, err error) {
	ctx, span := observability.TraceDatabaseFunction(ctx, "cleanup_expired_translations")
	defer observability.FinishSpan(span, &err)

	res, execErr := r.db.ExecContext(ctx, `DELETE FROM translation_cache WHERE expires_at < NOW()`)
	if execErr != nil {
		return 0, contextutils.WrapError(execErr, "failed to cleanup expired translations")
	}

	deleted, countErr := res.RowsAffected()
	if countErr != nil {
		return 0, contextutils.WrapError(countErr, "failed to get rows affected")
	}

	span.SetAttributes(attribute.Int64("cache.deleted_count", deleted))
	r.logger.Info(ctx, "Cleaned up expired translation cache entries", map[string]interface{}{
		"deleted_count": deleted,
	})
	return deleted, nil
}
// InMemoryTranslationCacheRepository is an in-memory implementation for testing.
// All access to the map goes through mu.
type InMemoryTranslationCacheRepository struct {
// cache maps "textHash|sourceLang|targetLang" to the stored entry.
cache map[string]*models.TranslationCache
// mu guards cache: lookups take the read lock, writes and cleanup the write lock.
mu sync.RWMutex
}
// NewInMemoryTranslationCacheRepository returns an empty in-memory translation
// cache with its backing map already allocated.
func NewInMemoryTranslationCacheRepository() *InMemoryTranslationCacheRepository {
	repo := &InMemoryTranslationCacheRepository{}
	repo.cache = make(map[string]*models.TranslationCache)
	return repo
}
// GetCachedTranslation returns the entry stored for (textHash, sourceLang,
// targetLang) when it exists and has not expired, and (nil, nil) otherwise.
// The returned pointer is the stored entry itself, not a copy. The error is always nil.
func (m *InMemoryTranslationCacheRepository) GetCachedTranslation(_ context.Context, textHash, sourceLang, targetLang string) (*models.TranslationCache, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	entry, ok := m.cache[fmt.Sprintf("%s|%s|%s", textHash, sourceLang, targetLang)]
	// An entry whose expiry time has passed is treated as a miss.
	if ok && !time.Now().After(entry.ExpiresAt) {
		return entry, nil
	}
	return nil, nil
}
// SaveTranslation stores a translation in the in-memory cache under the key
// "textHash|sourceLang|targetLang", replacing any existing entry. The entry
// expires 30 days after its creation time. The error is always nil.
func (m *InMemoryTranslationCacheRepository) SaveTranslation(_ context.Context, textHash, originalText, sourceLang, targetLang, translatedText string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	// Read the clock once so ExpiresAt is exactly CreatedAt + 30 days.
	now := time.Now()
	key := textHash + "|" + sourceLang + "|" + targetLang
	m.cache[key] = &models.TranslationCache{
		TextHash:       textHash,
		OriginalText:   originalText,
		SourceLanguage: sourceLang,
		TargetLanguage: targetLang,
		TranslatedText: translatedText,
		CreatedAt:      now,
		ExpiresAt:      now.Add(30 * 24 * time.Hour),
	}
	return nil
}
// CleanupExpiredTranslations drops every in-memory entry whose expiry time is
// before the moment of the call and returns how many entries were dropped.
// The error is always nil.
func (m *InMemoryTranslationCacheRepository) CleanupExpiredTranslations(_ context.Context) (int64, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	cutoff := time.Now()
	var removed int64
	for key, entry := range m.cache {
		if !cutoff.After(entry.ExpiresAt) {
			continue // still valid
		}
		delete(m.cache, key)
		removed++
	}
	return removed, nil
}
package services
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/serviceinterfaces"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/attribute"
)
// TranslationServiceInterface defines the interface for translation services.
// It is a type alias for serviceinterfaces.TranslationService, so the two
// names denote the same type.
type TranslationServiceInterface = serviceinterfaces.TranslationService
// GoogleTranslationService handles translation requests using Google Translate API.
// Results are cached through cacheRepo and usage is tracked through usageStatsSvc.
type GoogleTranslationService struct {
// config supplies the translation settings, provider table and language levels.
config *config.Config
// httpClient performs the outbound API calls.
httpClient *http.Client
// usageStatsSvc records usage and enforces the quota check.
usageStatsSvc UsageStatsServiceInterface
// cacheRepo stores and retrieves previously translated texts.
cacheRepo TranslationCacheRepository
// logger receives non-fatal errors (cache and usage-recording failures).
logger *observability.Logger
}
// NewGoogleTranslationService wires a GoogleTranslationService from its
// dependencies. The service gets its own HTTP client with a 30-second timeout.
func NewGoogleTranslationService(config *config.Config, usageStatsSvc UsageStatsServiceInterface, cacheRepo TranslationCacheRepository, logger *observability.Logger) *GoogleTranslationService {
	client := &http.Client{Timeout: 30 * time.Second}

	svc := &GoogleTranslationService{}
	svc.config = config
	svc.httpClient = client
	svc.usageStatsSvc = usageStatsSvc
	svc.cacheRepo = cacheRepo
	svc.logger = logger
	return svc
}
// GoogleTranslateRequest represents the request format for Google Translate API.
// It is marshalled to JSON as the POST body.
type GoogleTranslateRequest struct {
// Q holds the texts to translate.
Q []string `json:"q"`
// Target is the target language code.
Target string `json:"target"`
// Source is the source language code; it is omitted from the JSON when empty.
Source string `json:"source,omitempty"`
// Format is the content format; Translate sends "text".
Format string `json:"format"`
}
// normalizeLanguageCode maps a configured language name to its code.
// lang is compared case-insensitively against the keys of languageLevels; on a
// match the configured Code is returned. Anything else (an ISO code already,
// or an unknown value) is returned unchanged.
func normalizeLanguageCode(lang string, languageLevels map[string]config.LanguageLevelConfig) string {
	for name := range languageLevels {
		if !strings.EqualFold(name, lang) {
			continue
		}
		return languageLevels[name].Code
	}
	return lang
}
// GoogleTranslateResponse represents the response format from Google Translate API.
// Translate uses only the first element of Data.Translations.
type GoogleTranslateResponse struct {
Data struct {
Translations []struct {
// TranslatedText is the translated output for the corresponding input text.
TranslatedText string `json:"translatedText"`
// DetectedSourceLanguage is decoded from the response but not read by Translate.
DetectedSourceLanguage string `json:"detectedSourceLanguage"`
} `json:"translations"`
} `json:"data"`
}
// Translate translates req.Text using the configured default translation
// provider, consulting the translation cache first.
//
// Flow:
//  1. Fail when translation is disabled or the default provider is not configured.
//  2. Normalize both language identifiers and look the text hash up in the
//     cache. A hit is returned immediately and recorded as "translation_cache_hit".
//  3. On a miss (or a cache lookup error, which is only logged) record
//     "translation_cache_miss", check the quota and validate the request.
//  4. POST the request to the provider, record "translation" usage and save the
//     result in the cache under the same normalized languages used for lookup.
//
// Failures to record usage or to save to the cache are logged and do not fail
// the request. Text length is measured with len, i.e. in bytes.
//
// Errors: an AppError for a disabled/unconfigured service, missing API key,
// invalid input, non-200 provider status or empty provider response; the
// quota error from CheckQuota; wrapped errors for marshal/HTTP/decode failures.
func (s *GoogleTranslationService) Translate(ctx context.Context, req serviceinterfaces.TranslateRequest) (result *serviceinterfaces.TranslateResponse, err error) {
	ctx, span := observability.TraceTranslationFunction(ctx, "translate",
		attribute.String("translation.target_language", req.TargetLanguage),
		attribute.String("translation.source_language", req.SourceLanguage),
		attribute.Int("translation.text_length", len(req.Text)),
	)
	defer observability.FinishSpan(span, &err)

	if !s.config.Translation.Enabled {
		return nil, contextutils.NewAppError(contextutils.ErrorCodeServiceUnavailable, contextutils.SeverityError, "Translation service is disabled", "")
	}

	// Get provider config for usage stats and quota checking
	providerConfig, exists := s.config.Translation.Providers[s.config.Translation.DefaultProvider]
	if !exists {
		err = contextutils.NewAppError(contextutils.ErrorCodeServiceUnavailable, contextutils.SeverityError, "Translation provider not configured", "")
		return nil, err
	}
	span.SetAttributes(attribute.String("translation.provider", providerConfig.Code))

	// Generate hash for cache lookup
	textHash := HashText(req.Text)
	span.SetAttributes(attribute.String("cache.text_hash", textHash))

	// Normalize both languages so cache reads and writes use the same key
	normalizedSourceLang := normalizeLanguageCode(req.SourceLanguage, s.config.LanguageLevels)
	normalizedTargetLang := normalizeLanguageCode(req.TargetLanguage, s.config.LanguageLevels)

	// Check cache first (provider-agnostic)
	cachedTranslation, err := s.cacheRepo.GetCachedTranslation(ctx, textHash, normalizedSourceLang, normalizedTargetLang)
	if err != nil {
		// Log cache error but don't fail the translation request
		s.logger.Error(ctx, "Failed to check translation cache", err, map[string]interface{}{
			"text_hash":       textHash,
			"source_language": normalizedSourceLang,
			"target_language": normalizedTargetLang,
		})
	} else if cachedTranslation != nil {
		// Cache hit - return cached translation
		span.SetAttributes(
			attribute.Bool("cache.hit", true),
			attribute.String("cache.created_at", cachedTranslation.CreatedAt.Format(time.RFC3339)),
		)
		// Record cache hit in usage stats
		if err := s.usageStatsSvc.RecordUsage(ctx, providerConfig.Code, "translation_cache_hit", len(req.Text), 1); err != nil {
			s.logger.Error(ctx, "Failed to record translation cache hit", err)
		}
		return &serviceinterfaces.TranslateResponse{
			TranslatedText: cachedTranslation.TranslatedText,
			SourceLanguage: cachedTranslation.SourceLanguage,
			TargetLanguage: cachedTranslation.TargetLanguage,
		}, nil
	}

	// Cache miss - proceed with API call
	span.SetAttributes(attribute.Bool("cache.hit", false))

	// Record cache miss in usage stats
	if err := s.usageStatsSvc.RecordUsage(ctx, providerConfig.Code, "translation_cache_miss", 0, 1); err != nil {
		s.logger.Error(ctx, "Failed to record translation cache miss", err)
	}

	// Check quota before making the request
	if err := s.usageStatsSvc.CheckQuota(ctx, providerConfig.Code, "translation", len(req.Text)); err != nil {
		return nil, err
	}

	if providerConfig.APIKey == "" {
		err = contextutils.NewAppError(contextutils.ErrorCodeServiceUnavailable, contextutils.SeverityError, "Google Translate API key not configured", "")
		return nil, err
	}
	if req.SourceLanguage == "" || req.TargetLanguage == "" {
		err = contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, "Source and target language are required", "")
		return nil, err
	}
	if len(req.Text) == 0 {
		err = contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, "Text cannot be empty", "")
		return nil, err
	}
	if len(req.Text) > providerConfig.MaxTextLength {
		err = contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, fmt.Sprintf("Text cannot exceed %d characters", providerConfig.MaxTextLength), "")
		return nil, err
	}

	// Prepare request - use normalized language codes for Google Translate API
	requestBody := GoogleTranslateRequest{
		Q:      []string{req.Text},
		Target: normalizedTargetLang,
		Source: normalizedSourceLang,
		Format: "text",
	}
	jsonBody, err := json.Marshal(requestBody)
	if err != nil {
		err = contextutils.WrapError(err, "failed to marshal request")
		return nil, err
	}

	// Build URL (the API key travels as a query parameter)
	endpoint := fmt.Sprintf("%s%s?key=%s", providerConfig.BaseURL, providerConfig.APIEndpoint, providerConfig.APIKey)

	// Make request; the request already carries ctx from NewRequestWithContext
	httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonBody))
	if err != nil {
		err = contextutils.WrapError(err, "failed to create request")
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := s.httpClient.Do(httpReq)
	if err != nil {
		err = contextutils.WrapError(err, "translation request failed")
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != http.StatusOK {
		// Best effort: include whatever body could be read in the error message.
		body, _ := io.ReadAll(resp.Body)
		err = contextutils.NewAppError(contextutils.ErrorCodeServiceUnavailable, contextutils.SeverityError,
			fmt.Sprintf("Google Translate API error: %d - %s", resp.StatusCode, string(body)), "")
		return nil, err
	}

	// Parse response
	var googleResp GoogleTranslateResponse
	if err := json.NewDecoder(resp.Body).Decode(&googleResp); err != nil {
		err = contextutils.WrapError(err, "failed to decode response")
		return nil, err
	}
	if len(googleResp.Data.Translations) == 0 {
		err = contextutils.NewAppError(contextutils.ErrorCodeServiceUnavailable, contextutils.SeverityError, "No translation returned from Google Translate API", "")
		return nil, err
	}

	translation := googleResp.Data.Translations[0]
	result = &serviceinterfaces.TranslateResponse{
		TranslatedText: translation.TranslatedText,
		SourceLanguage: normalizedSourceLang,
		TargetLanguage: normalizedTargetLang,
	}

	// Record usage after successful translation
	if err := s.usageStatsSvc.RecordUsage(ctx, providerConfig.Code, "translation", len(req.Text), 1); err != nil {
		// Log the error but don't fail the translation request
		// The translation was successful, we just couldn't record the usage
		// This is a non-critical error that should be logged for monitoring
		s.logger.Error(ctx, "Failed to record translation usage", err, map[string]interface{}{
			"service":    providerConfig.Code,
			"usage_type": "translation",
			"characters": len(req.Text),
			"requests":   1,
		})
	}

	// Save translation to cache under the normalized source AND target
	// languages. The lookup above uses the normalized pair, so saving under the
	// raw req.TargetLanguage would create entries that are never found again
	// whenever the caller passes a language name instead of a code.
	if err := s.cacheRepo.SaveTranslation(ctx, textHash, req.Text, result.SourceLanguage, result.TargetLanguage, result.TranslatedText); err != nil {
		// Log the error but don't fail the translation request
		span.SetAttributes(attribute.Bool("cache.save_error", true))
		s.logger.Error(ctx, "Failed to save translation to cache", err, map[string]interface{}{
			"text_hash":       textHash,
			"source_language": result.SourceLanguage,
			"target_language": result.TargetLanguage,
		})
	} else {
		span.SetAttributes(attribute.Bool("cache.saved", true))
	}

	return result, nil
}
// ValidateLanguageCode checks the shape of a language code: its length must be
// between 2 and 10 bytes and it may contain only ASCII letters, digits and
// hyphens. It returns an invalid-input AppError otherwise and nil on success.
// It does not check that the code names a real language.
func (s *GoogleTranslationService) ValidateLanguageCode(langCode string) error {
	if n := len(langCode); n < 2 || n > 10 {
		return contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, "Language code must be 2-10 characters", "")
	}

	for _, r := range langCode {
		switch {
		case r >= 'a' && r <= 'z':
		case r >= 'A' && r <= 'Z':
		case r >= '0' && r <= '9':
		case r == '-':
		default:
			return contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, "Invalid language code format", "")
		}
	}
	return nil
}
// GetSupportedLanguages returns a list of supported target languages for translation.
// The list is a hard-coded set of language codes; it is not fetched from the
// provider, and a new slice is returned on every call.
func (s *GoogleTranslationService) GetSupportedLanguages() []string {
// Common languages supported by Google Translate API
return []string{
"af", "sq", "am", "ar", "hy", "az", "eu", "be", "bn", "bs", "bg", "ca", "ceb", "ny", "zh", "zh-CN", "zh-TW",
"co", "hr", "cs", "da", "nl", "en", "eo", "et", "tl", "fi", "fr", "fy", "gl", "ka", "de", "el", "gu", "ht",
"ha", "haw", "iw", "hi", "hmn", "hu", "is", "ig", "id", "ga", "it", "ja", "jw", "kn", "kk", "km", "ko", "ku",
"ky", "lo", "la", "lv", "lt", "lb", "mk", "mg", "ms", "ml", "mt", "mi", "mr", "mn", "my", "ne", "no", "ps",
"fa", "pl", "pt", "pa", "ro", "ru", "sm", "gd", "sr", "st", "sn", "sd", "si", "sk", "sl", "so", "es", "su",
"sw", "sv", "tg", "ta", "te", "th", "tr", "uk", "ur", "uz", "vi", "cy", "xh", "yi", "yo", "zu",
}
}
// NoopTranslationService is a no-operation implementation for testing and development.
// It holds no state; its Translate method echoes the input text back.
type NoopTranslationService struct{}
// NewNoopTranslationService returns a ready-to-use no-op translation service.
func NewNoopTranslationService() *NoopTranslationService {
	return new(NoopTranslationService)
}
// Translate performs no translation: it echoes req.Text back as the translated
// text together with the requested source and target languages and a
// confidence of 1.0. The error is always nil.
func (s *NoopTranslationService) Translate(_ context.Context, req serviceinterfaces.TranslateRequest) (*serviceinterfaces.TranslateResponse, error) {
	resp := new(serviceinterfaces.TranslateResponse)
	resp.TranslatedText = req.Text
	resp.SourceLanguage = req.SourceLanguage
	resp.TargetLanguage = req.TargetLanguage
	resp.Confidence = 1.0
	return resp, nil
}
// ValidateLanguageCode applies the same shape check as the Google service: the
// code must be 2 to 10 bytes long and consist solely of ASCII letters, digits
// and hyphens. A violation yields an invalid-input AppError; otherwise nil.
func (s *NoopTranslationService) ValidateLanguageCode(langCode string) error {
	if size := len(langCode); size < 2 || size > 10 {
		return contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, "Language code must be 2-10 characters", "")
	}

	for _, r := range langCode {
		isLower := r >= 'a' && r <= 'z'
		isUpper := r >= 'A' && r <= 'Z'
		isDigit := r >= '0' && r <= '9'
		if isLower || isUpper || isDigit || r == '-' {
			continue
		}
		return contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityError, "Invalid language code format", "")
	}
	return nil
}
// GetSupportedLanguages returns a list of supported target languages for translation.
// The no-op service reports a fixed list of ten language codes; a new slice is
// returned on every call.
func (s *NoopTranslationService) GetSupportedLanguages() []string {
// Return a subset of common languages for testing
return []string{
"en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh",
}
}
// NewTranslationService selects the translation service implementation from
// configuration.
//
// A Google-backed service is returned only when translation is enabled, the
// default provider exists in the provider table and its Code is "google". In
// every other case (disabled, provider missing, unsupported provider) the
// no-op service is returned.
func NewTranslationService(config *config.Config, usageStatsSvc UsageStatsServiceInterface, cacheRepo TranslationCacheRepository, logger *observability.Logger) TranslationServiceInterface {
	if !config.Translation.Enabled {
		return NewNoopTranslationService()
	}

	provider, configured := config.Translation.Providers[config.Translation.DefaultProvider]
	if configured && provider.Code == "google" {
		return NewGoogleTranslationService(config, usageStatsSvc, cacheRepo, logger)
	}

	// Provider missing or not supported: fall back to the no-op service.
	return NewNoopTranslationService()
}
package services
import (
"context"
"database/sql"
"fmt"
"time"
"quizapp/internal/config"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
openapi_types "github.com/oapi-codegen/runtime/types"
"go.opentelemetry.io/otel/attribute"
)
// UsageStatsServiceInterface defines the interface for usage statistics tracking.
// It covers two areas: per-service monthly usage with quota checks, and
// per-user AI token usage.
type UsageStatsServiceInterface interface {
// CheckQuota checks if a translation request would exceed the monthly quota
CheckQuota(ctx context.Context, serviceName, usageType string, characters int) error
// RecordUsage records the usage of a translation service
RecordUsage(ctx context.Context, serviceName, usageType string, characters, requests int) error
// GetCurrentMonthUsage returns the current month's usage for a service and type
GetCurrentMonthUsage(ctx context.Context, serviceName, usageType string) (*UsageStats, error)
// GetMonthlyQuota returns the monthly quota for a service
GetMonthlyQuota(serviceName string) int64
// GetAllUsageStats returns all usage statistics (for admin interface)
GetAllUsageStats(ctx context.Context) ([]*UsageStats, error)
// GetUsageStatsByService returns usage statistics for a specific service
GetUsageStatsByService(ctx context.Context, serviceName string) ([]*UsageStats, error)
// GetUsageStatsByMonth returns usage statistics for a specific month
GetUsageStatsByMonth(ctx context.Context, year, month int) ([]*UsageStats, error)
// AI Token usage tracking for users
// RecordUserAITokenUsage records AI token usage for a specific user
RecordUserAITokenUsage(ctx context.Context, userID int, apiKeyID *int, provider, model, usageType string, promptTokens, completionTokens, totalTokens, requests int) error
// GetUserAITokenUsageStats returns AI token usage statistics for a specific user
GetUserAITokenUsageStats(ctx context.Context, userID int, startDate, endDate time.Time) ([]*UserUsageStats, error)
// GetUserAITokenUsageStatsByDay returns daily aggregated AI token usage for a user
GetUserAITokenUsageStatsByDay(ctx context.Context, userID int, startDate, endDate time.Time) ([]*UserUsageStatsDaily, error)
// GetUserAITokenUsageStatsByHour returns hourly aggregated AI token usage for a user on a specific day
GetUserAITokenUsageStatsByHour(ctx context.Context, userID int, date time.Time) ([]*UserUsageStatsHourly, error)
}
// UsageStats represents usage statistics for a service in a given month.
// It mirrors one row of the usage_stats table, keyed by service name, usage
// type and month.
type UsageStats struct {
ID int `json:"id"`
ServiceName string `json:"service_name"`
UsageType string `json:"usage_type"`
// UsageMonth identifies the month; writers in this file store its first day.
UsageMonth time.Time `json:"usage_month"`
CharactersUsed int `json:"characters_used"`
RequestsMade int `json:"requests_made"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// UserUsageStats represents detailed usage statistics for a user.
// Its fields correspond to the columns written to user_usage_stats: one record
// per user, API key, date, hour, service, provider, model and usage type.
type UserUsageStats struct {
ID int `json:"id"`
UserID int `json:"user_id"`
// APIKeyID is optional and omitted from JSON when nil.
APIKeyID *int `json:"api_key_id,omitempty"`
UsageDate time.Time `json:"usage_date"`
UsageHour int `json:"usage_hour"`
ServiceName string `json:"service_name"`
Provider string `json:"provider"`
Model string `json:"model"`
UsageType string `json:"usage_type"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
RequestsMade int `json:"requests_made"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// UserUsageStatsDaily represents daily aggregated usage for a user.
// Each value carries the token and request totals for one date, service,
// provider, model and usage type.
type UserUsageStatsDaily struct {
UsageDate openapi_types.Date `json:"usage_date"`
ServiceName string `json:"service_name"`
Provider string `json:"provider"`
Model string `json:"model"`
UsageType string `json:"usage_type"`
TotalPromptTokens int `json:"total_prompt_tokens"`
TotalCompletionTokens int `json:"total_completion_tokens"`
TotalTokens int `json:"total_tokens"`
TotalRequests int `json:"total_requests"`
}
// UserUsageStatsHourly represents hourly usage for a user on a specific day.
// Each value carries the token and request totals for one hour, service,
// provider, model and usage type.
type UserUsageStatsHourly struct {
UsageHour int `json:"usage_hour"`
ServiceName string `json:"service_name"`
Provider string `json:"provider"`
Model string `json:"model"`
UsageType string `json:"usage_type"`
TotalPromptTokens int `json:"total_prompt_tokens"`
TotalCompletionTokens int `json:"total_completion_tokens"`
TotalTokens int `json:"total_tokens"`
TotalRequests int `json:"total_requests"`
}
// UsageStatsService handles usage statistics tracking and quota management.
// It persists counters in the usage_stats and user_usage_stats tables.
type UsageStatsService struct {
// config supplies the translation quota settings.
config *config.Config
// db is the connection pool every query runs against.
db *sql.DB
// logger is stored at construction time.
logger *observability.Logger
}
// NewUsageStatsService assembles a UsageStatsService from the application
// configuration, a database handle and a logger.
func NewUsageStatsService(config *config.Config, db *sql.DB, logger *observability.Logger) *UsageStatsService {
	svc := new(UsageStatsService)
	svc.config = config
	svc.db = db
	svc.logger = logger
	return svc
}
// CheckQuota reports whether adding characters to the current month's usage
// of serviceName/usageType would exceed the service's monthly quota.
//
// It returns nil when quota checking is disabled or the new total stays within
// the quota, a quota-exceeded AppError (severity warn) when it does not, and a
// wrapped error when the current usage cannot be loaded.
func (s *UsageStatsService) CheckQuota(ctx context.Context, serviceName, usageType string, characters int) (err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "check_quota",
		attribute.String("service_name", serviceName),
		attribute.String("usage_type", usageType),
		attribute.Int("characters", characters),
	)
	defer observability.FinishSpan(span, &err)

	if !s.config.Translation.Quota.Enabled {
		return nil // Quota checking disabled
	}

	currentUsage, err := s.GetCurrentMonthUsage(ctx, serviceName, usageType)
	if err != nil {
		return contextutils.WrapError(err, "failed to get current usage")
	}

	quota := s.GetMonthlyQuota(serviceName)
	// Compare in int64: narrowing the int64 quota to int would truncate it on
	// 32-bit platforms, and the sum of two ints could overflow there.
	newTotal := int64(currentUsage.CharactersUsed) + int64(characters)
	if newTotal > quota {
		return contextutils.NewAppError(
			contextutils.ErrorCodeQuotaExceeded,
			contextutils.SeverityWarn,
			fmt.Sprintf("Monthly quota exceeded for %s %s service. Used: %d/%d characters",
				serviceName, usageType, newTotal, quota),
			"",
		)
	}
	return nil
}
// RecordUsage adds characters and requests to the usage_stats counters of
// serviceName/usageType for the current UTC month. The row is created on first
// use and incremented afterwards (upsert). A database failure is returned wrapped.
func (s *UsageStatsService) RecordUsage(ctx context.Context, serviceName, usageType string, characters, requests int) (err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "record_usage",
		attribute.String("service_name", serviceName),
		attribute.String("usage_type", usageType),
		attribute.Int("characters", characters),
		attribute.Int("requests", requests),
	)
	defer observability.FinishSpan(span, &err)

	// First day of the current month at 00:00 UTC. The clock is read once:
	// two separate reads could straddle a month boundary and yield the last
	// day of the previous month instead.
	now := time.Now().UTC()
	currentMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)

	query := `
INSERT INTO usage_stats (service_name, usage_type, usage_month, characters_used, requests_made, updated_at)
VALUES ($1, $2, $3, $4, $5, NOW())
ON CONFLICT (service_name, usage_type, usage_month)
DO UPDATE SET
characters_used = usage_stats.characters_used + $4,
requests_made = usage_stats.requests_made + $5,
updated_at = NOW()`
	_, err = s.db.ExecContext(ctx, query, serviceName, usageType, currentMonth, characters, requests)
	if err != nil {
		return contextutils.WrapError(err, "failed to record usage")
	}
	return nil
}
// RecordUserAITokenUsage adds the given token and request counts to the
// user_usage_stats bucket identified by user, API key, date, hour, service
// ("ai"), provider, model and usage type. The bucket is created on first use
// and incremented afterwards (upsert). apiKeyID may be nil. A database failure
// is returned wrapped.
//
// The date and hour of the bucket are both taken from one UTC instant.
// NOTE(review): readers of user_usage_stats are assumed to interpret
// usage_date/usage_hour as UTC — confirm against the query side.
func (s *UsageStatsService) RecordUserAITokenUsage(ctx context.Context, userID int, apiKeyID *int, provider, model, usageType string, promptTokens, completionTokens, totalTokens, requests int) (err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "record_user_ai_token_usage",
		attribute.Int("user_id", userID),
		attribute.String("provider", provider),
		attribute.String("model", model),
		attribute.String("usage_type", usageType),
		attribute.Int("prompt_tokens", promptTokens),
		attribute.Int("completion_tokens", completionTokens),
		attribute.Int("total_tokens", totalTokens),
		attribute.Int("requests", requests),
	)
	defer observability.FinishSpan(span, &err)

	// time.Time.Truncate works on absolute time, so truncating a local time to
	// 24h yields UTC midnight while Hour() would still be the local hour. Using
	// a UTC instant for both keeps the date and the hour in the same zone.
	now := time.Now().UTC()
	usageDate := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) // Start of day (UTC)
	usageHour := now.Hour()

	query := `
INSERT INTO user_usage_stats (user_id, api_key_id, usage_date, usage_hour, service_name, provider, model, usage_type, prompt_tokens, completion_tokens, total_tokens, requests_made, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NOW())
ON CONFLICT (user_id, api_key_id, usage_date, usage_hour, service_name, provider, model, usage_type)
DO UPDATE SET
prompt_tokens = user_usage_stats.prompt_tokens + $9,
completion_tokens = user_usage_stats.completion_tokens + $10,
total_tokens = user_usage_stats.total_tokens + $11,
requests_made = user_usage_stats.requests_made + $12,
updated_at = NOW()`
	_, err = s.db.ExecContext(ctx, query, userID, apiKeyID, usageDate, usageHour, "ai", provider, model, usageType, promptTokens, completionTokens, totalTokens, requests)
	if err != nil {
		return contextutils.WrapError(err, "failed to record user ai token usage")
	}
	return nil
}
// GetCurrentMonthUsage returns the usage_stats row of serviceName/usageType
// for the current UTC month.
//
// When no row exists yet, a zero-valued UsageStats carrying the requested
// service, usage type and month is returned with a nil error. Any other query
// failure is returned wrapped.
func (s *UsageStatsService) GetCurrentMonthUsage(ctx context.Context, serviceName, usageType string) (stats *UsageStats, err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "get_current_month_usage",
		attribute.String("service_name", serviceName),
		attribute.String("usage_type", usageType),
	)
	defer observability.FinishSpan(span, &err)

	// First day of the current month at 00:00 UTC. The clock is read once:
	// two separate reads could straddle a month boundary and yield the last
	// day of the previous month instead.
	now := time.Now().UTC()
	currentMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)

	query := `
SELECT id, service_name, usage_type, usage_month, characters_used, requests_made, created_at, updated_at
FROM usage_stats
WHERE service_name = $1 AND usage_type = $2 AND usage_month = $3`
	stats = &UsageStats{}
	err = s.db.QueryRowContext(ctx, query, serviceName, usageType, currentMonth).Scan(
		&stats.ID, &stats.ServiceName, &stats.UsageType, &stats.UsageMonth,
		&stats.CharactersUsed, &stats.RequestsMade, &stats.CreatedAt, &stats.UpdatedAt,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			// Return empty stats for new service/month
			return &UsageStats{
				ServiceName:    serviceName,
				UsageType:      usageType,
				UsageMonth:     currentMonth,
				CharactersUsed: 0,
				RequestsMade:   0,
			}, nil
		}
		return nil, contextutils.WrapError(err, "failed to get usage stats")
	}
	return stats, nil
}
// GetMonthlyQuota reports the configured monthly translation quota for
// serviceName. It returns 0 when quota enforcement is disabled in the
// configuration, the Google-specific quota for "google", and the default
// quota for every other service name.
func (s *UsageStatsService) GetMonthlyQuota(serviceName string) int64 {
	if !s.config.Translation.Quota.Enabled {
		// Quotas switched off: 0 stands for "no limit".
		return 0
	}
	if serviceName == "google" {
		return s.config.Translation.Quota.GoogleMonthlyQuota
	}
	return s.config.Translation.Quota.DefaultMonthlyQuota
}
// GetAllUsageStats loads every row of usage_stats (for the admin interface),
// newest month first and then ordered by service name and usage type.
// The returned slice is non-nil even when the table is empty. Query, scan and
// iteration failures are returned wrapped; a failure to close the result set
// is only logged.
func (s *UsageStatsService) GetAllUsageStats(ctx context.Context) (stats []*UsageStats, err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "get_all_usage_stats")
	defer observability.FinishSpan(span, &err)

	const query = `
SELECT id, service_name, usage_type, usage_month, characters_used, requests_made, created_at, updated_at
FROM usage_stats
ORDER BY usage_month DESC, service_name, usage_type`

	rows, queryErr := s.db.QueryContext(ctx, query)
	if queryErr != nil {
		return nil, contextutils.WrapError(queryErr, "failed to query usage stats")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
	}()

	stats = make([]*UsageStats, 0)
	for rows.Next() {
		// A fresh record per row so each appended pointer is distinct.
		record := new(UsageStats)
		if scanErr := rows.Scan(
			&record.ID, &record.ServiceName, &record.UsageType, &record.UsageMonth,
			&record.CharactersUsed, &record.RequestsMade, &record.CreatedAt, &record.UpdatedAt,
		); scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan usage stats")
		}
		stats = append(stats, record)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating usage stats")
	}
	return stats, nil
}
// GetUsageStatsByService loads the usage_stats rows whose service_name equals
// serviceName, newest month first and then ordered by usage type.
// The returned slice is non-nil even when nothing matches. Query, scan and
// iteration failures are returned wrapped; a failure to close the result set
// is only logged.
func (s *UsageStatsService) GetUsageStatsByService(ctx context.Context, serviceName string) (stats []*UsageStats, err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "get_usage_stats_by_service",
		attribute.String("service_name", serviceName),
	)
	defer observability.FinishSpan(span, &err)

	const query = `
SELECT id, service_name, usage_type, usage_month, characters_used, requests_made, created_at, updated_at
FROM usage_stats
WHERE service_name = $1
ORDER BY usage_month DESC, usage_type`

	rows, queryErr := s.db.QueryContext(ctx, query, serviceName)
	if queryErr != nil {
		return nil, contextutils.WrapError(queryErr, "failed to query usage stats by service")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
	}()

	stats = make([]*UsageStats, 0)
	for rows.Next() {
		// A fresh record per row so each appended pointer is distinct.
		record := new(UsageStats)
		if scanErr := rows.Scan(
			&record.ID, &record.ServiceName, &record.UsageType, &record.UsageMonth,
			&record.CharactersUsed, &record.RequestsMade, &record.CreatedAt, &record.UpdatedAt,
		); scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan usage stats")
		}
		stats = append(stats, record)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating usage stats")
	}
	return stats, nil
}
// GetUsageStatsByMonth loads the usage_stats rows whose usage_month equals
// midnight UTC on the first day of the given year and month, ordered by
// service name and usage type.
// The returned slice is non-nil even when nothing matches. Query, scan and
// iteration failures are returned wrapped; a failure to close the result set
// is only logged.
func (s *UsageStatsService) GetUsageStatsByMonth(ctx context.Context, year, month int) (stats []*UsageStats, err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "get_usage_stats_by_month",
		attribute.Int("year", year),
		attribute.Int("month", month),
	)
	defer observability.FinishSpan(span, &err)

	// The lookup key is the first day of the requested month, in UTC.
	firstOfMonth := time.Date(year, time.Month(month), 1, 0, 0, 0, 0, time.UTC)

	const query = `
SELECT id, service_name, usage_type, usage_month, characters_used, requests_made, created_at, updated_at
FROM usage_stats
WHERE usage_month = $1
ORDER BY service_name, usage_type`

	rows, queryErr := s.db.QueryContext(ctx, query, firstOfMonth)
	if queryErr != nil {
		return nil, contextutils.WrapError(queryErr, "failed to query usage stats by month")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
	}()

	stats = make([]*UsageStats, 0)
	for rows.Next() {
		// A fresh record per row so each appended pointer is distinct.
		record := new(UsageStats)
		if scanErr := rows.Scan(
			&record.ID, &record.ServiceName, &record.UsageType, &record.UsageMonth,
			&record.CharactersUsed, &record.RequestsMade, &record.CreatedAt, &record.UpdatedAt,
		); scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan usage stats")
		}
		stats = append(stats, record)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating usage stats")
	}
	return stats, nil
}
// GetUserAITokenUsageStats loads the raw user_usage_stats rows for userID
// whose usage_date lies between startDate and endDate (both inclusive),
// newest day and hour first.
// The returned slice is non-nil even when nothing matches. Query, scan and
// iteration failures are returned wrapped; a failure to close the result set
// is only logged as a warning.
func (s *UsageStatsService) GetUserAITokenUsageStats(ctx context.Context, userID int, startDate, endDate time.Time) (stats []*UserUsageStats, err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "get_user_ai_token_usage_stats",
		attribute.Int("user_id", userID),
		attribute.String("start_date", startDate.Format("2006-01-02")),
		attribute.String("end_date", endDate.Format("2006-01-02")),
	)
	defer observability.FinishSpan(span, &err)

	const query = `
SELECT id, user_id, api_key_id, usage_date, usage_hour, service_name, provider, model, usage_type, prompt_tokens, completion_tokens, total_tokens, requests_made, created_at, updated_at
FROM user_usage_stats
WHERE user_id = $1 AND usage_date >= $2 AND usage_date <= $3
ORDER BY usage_date DESC, usage_hour DESC`

	rows, queryErr := s.db.QueryContext(ctx, query, userID, startDate, endDate)
	if queryErr != nil {
		return nil, contextutils.WrapError(queryErr, "failed to query user usage stats")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Warn(ctx, "Failed to close user usage stats query", map[string]interface{}{
			"error": closeErr.Error(),
		})
	}()

	stats = make([]*UserUsageStats, 0)
	for rows.Next() {
		// A fresh record per row so each appended pointer is distinct.
		record := new(UserUsageStats)
		scanErr := rows.Scan(
			&record.ID, &record.UserID, &record.APIKeyID, &record.UsageDate, &record.UsageHour,
			&record.ServiceName, &record.Provider, &record.Model, &record.UsageType,
			&record.PromptTokens, &record.CompletionTokens, &record.TotalTokens, &record.RequestsMade,
			&record.CreatedAt, &record.UpdatedAt,
		)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan user usage stats")
		}
		stats = append(stats, record)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating user usage stats")
	}
	return stats, nil
}
// GetUserAITokenUsageStatsByDay returns AI token usage for userID summed per
// day, service, provider, model and usage type, for usage dates between
// startDate and endDate (both inclusive), newest day first.
// The returned slice is non-nil even when nothing matches. Query, scan and
// iteration failures are returned wrapped; a failure to close the result set
// is only logged as a warning.
func (s *UsageStatsService) GetUserAITokenUsageStatsByDay(ctx context.Context, userID int, startDate, endDate time.Time) (stats []*UserUsageStatsDaily, err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "get_user_ai_token_usage_stats_by_day",
		attribute.Int("user_id", userID),
		attribute.String("start_date", startDate.Format("2006-01-02")),
		attribute.String("end_date", endDate.Format("2006-01-02")),
	)
	defer observability.FinishSpan(span, &err)

	const query = `
SELECT usage_date, service_name, provider, model, usage_type,
SUM(prompt_tokens) as total_prompt_tokens,
SUM(completion_tokens) as total_completion_tokens,
SUM(total_tokens) as total_tokens,
SUM(requests_made) as total_requests
FROM user_usage_stats
WHERE user_id = $1 AND usage_date >= $2 AND usage_date <= $3
GROUP BY usage_date, service_name, provider, model, usage_type
ORDER BY usage_date DESC, service_name, provider, model, usage_type`

	rows, queryErr := s.db.QueryContext(ctx, query, userID, startDate, endDate)
	if queryErr != nil {
		return nil, contextutils.WrapError(queryErr, "failed to query user daily usage stats")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Warn(ctx, "Failed to close user daily usage stats query", map[string]interface{}{
			"error": closeErr.Error(),
		})
	}()

	stats = make([]*UserUsageStatsDaily, 0)
	for rows.Next() {
		var (
			day    time.Time
			record = new(UserUsageStatsDaily)
		)
		scanErr := rows.Scan(
			&day, &record.ServiceName, &record.Provider, &record.Model, &record.UsageType,
			&record.TotalPromptTokens, &record.TotalCompletionTokens, &record.TotalTokens, &record.TotalRequests,
		)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan user daily usage stats")
		}
		// The date is scanned as a time.Time and then wrapped in the
		// OpenAPI date type used by the result struct.
		record.UsageDate = openapi_types.Date{Time: day}
		stats = append(stats, record)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating user daily usage stats")
	}
	return stats, nil
}
// GetUserAITokenUsageStatsByHour returns AI token usage for userID on a single
// day, summed per hour, service, provider, model and usage type and ordered by
// hour. The day window runs from date truncated to a 24-hour boundary up to
// one nanosecond before the next boundary (both ends inclusive in the query).
// The returned slice is non-nil even when nothing matches. Query, scan and
// iteration failures are returned wrapped; a failure to close the result set
// is only logged as a warning.
func (s *UsageStatsService) GetUserAITokenUsageStatsByHour(ctx context.Context, userID int, date time.Time) (stats []*UserUsageStatsHourly, err error) {
	ctx, span := observability.TraceUsageStatsFunction(ctx, "get_user_ai_token_usage_stats_by_hour",
		attribute.Int("user_id", userID),
		attribute.String("date", date.Format("2006-01-02")),
	)
	defer observability.FinishSpan(span, &err)

	// NOTE(review): Truncate rounds on absolute time, so the window starts at
	// a UTC midnight even if date carries another location — confirm intended.
	windowStart := date.Truncate(24 * time.Hour)
	windowEnd := windowStart.Add(24*time.Hour - time.Nanosecond)

	const query = `
SELECT usage_hour, service_name, provider, model, usage_type,
SUM(prompt_tokens) as total_prompt_tokens,
SUM(completion_tokens) as total_completion_tokens,
SUM(total_tokens) as total_tokens,
SUM(requests_made) as total_requests
FROM user_usage_stats
WHERE user_id = $1 AND usage_date >= $2 AND usage_date <= $3
GROUP BY usage_hour, service_name, provider, model, usage_type
ORDER BY usage_hour, service_name, provider, model, usage_type`

	rows, queryErr := s.db.QueryContext(ctx, query, userID, windowStart, windowEnd)
	if queryErr != nil {
		return nil, contextutils.WrapError(queryErr, "failed to query user hourly usage stats")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Warn(ctx, "Failed to close user hourly usage stats query", map[string]interface{}{
			"error": closeErr.Error(),
		})
	}()

	stats = make([]*UserUsageStatsHourly, 0)
	for rows.Next() {
		// A fresh record per row so each appended pointer is distinct.
		record := new(UserUsageStatsHourly)
		scanErr := rows.Scan(
			&record.UsageHour, &record.ServiceName, &record.Provider, &record.Model, &record.UsageType,
			&record.TotalPromptTokens, &record.TotalCompletionTokens, &record.TotalTokens, &record.TotalRequests,
		)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan user hourly usage stats")
		}
		stats = append(stats, record)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating user hourly usage stats")
	}
	return stats, nil
}
// NoopUsageStatsService is a do-nothing usage stats implementation, intended
// for tests and for running with quotas disabled. Every method succeeds
// without touching any storage and returns empty (but non-nil) results.
type NoopUsageStatsService struct{}

// NewNoopUsageStatsService returns a ready-to-use NoopUsageStatsService.
func NewNoopUsageStatsService() *NoopUsageStatsService {
	return new(NoopUsageStatsService)
}

// CheckQuota performs no quota check and always reports success.
func (s *NoopUsageStatsService) CheckQuota(_ context.Context, _, _ string, _ int) (err error) {
	return nil
}

// RecordUsage discards the usage and always reports success.
func (s *NoopUsageStatsService) RecordUsage(_ context.Context, _, _ string, _, _ int) (err error) {
	return nil
}

// GetCurrentMonthUsage returns a zero-valued UsageStats and no error.
func (s *NoopUsageStatsService) GetCurrentMonthUsage(_ context.Context, _, _ string) (stats *UsageStats, err error) {
	// Every field is left at its zero value: empty names, zero counters.
	return &UsageStats{}, nil
}

// GetMonthlyQuota always returns 0, i.e. no quota limit.
func (s *NoopUsageStatsService) GetMonthlyQuota(_ string) int64 {
	return 0
}

// GetAllUsageStats returns an empty, non-nil list of usage statistics.
func (s *NoopUsageStatsService) GetAllUsageStats(_ context.Context) ([]*UsageStats, error) {
	return make([]*UsageStats, 0), nil
}

// GetUsageStatsByService returns an empty, non-nil list regardless of service.
func (s *NoopUsageStatsService) GetUsageStatsByService(_ context.Context, _ string) ([]*UsageStats, error) {
	return make([]*UsageStats, 0), nil
}

// GetUsageStatsByMonth returns an empty, non-nil list regardless of month.
func (s *NoopUsageStatsService) GetUsageStatsByMonth(_ context.Context, _, _ int) ([]*UsageStats, error) {
	return make([]*UsageStats, 0), nil
}

// RecordUserAITokenUsage discards the usage and always reports success.
func (s *NoopUsageStatsService) RecordUserAITokenUsage(_ context.Context, _ int, _ *int, _, _, _ string, _, _, _, _ int) error {
	return nil
}

// GetUserAITokenUsageStats returns an empty, non-nil list of per-user stats.
func (s *NoopUsageStatsService) GetUserAITokenUsageStats(_ context.Context, _ int, _, _ time.Time) ([]*UserUsageStats, error) {
	return make([]*UserUsageStats, 0), nil
}

// GetUserAITokenUsageStatsByDay returns an empty, non-nil list of daily stats.
func (s *NoopUsageStatsService) GetUserAITokenUsageStatsByDay(_ context.Context, _ int, _, _ time.Time) ([]*UserUsageStatsDaily, error) {
	return make([]*UserUsageStatsDaily, 0), nil
}

// GetUserAITokenUsageStatsByHour returns an empty, non-nil list of hourly stats.
func (s *NoopUsageStatsService) GetUserAITokenUsageStatsByHour(_ context.Context, _ int, _ time.Time) ([]*UserUsageStatsHourly, error) {
	return make([]*UserUsageStatsHourly, 0), nil
}
package services
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"github.com/lib/pq"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/bcrypt"
)
// UserServiceInterface defines the interface for user-related operations.
// This allows for easier mocking in tests.
//
// The methods fall into five groups: user creation and lookup, profile and
// settings updates, bulk/administrative operations, per-provider API key
// storage, and role management. Lookups that return (*models.User, error)
// are documented on the concrete implementation; presumably *UserService in
// this package is the production implementation — confirm against callers.
type UserServiceInterface interface {
// User creation, lookup and authentication
CreateUserWithPassword(ctx context.Context, username, password, language, level string) (*models.User, error)
CreateUserWithEmailAndTimezone(ctx context.Context, username, email, timezone, language, level string) (*models.User, error)
GetUserByID(ctx context.Context, id int) (*models.User, error)
GetUserByUsername(ctx context.Context, username string) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
AuthenticateUser(ctx context.Context, username, password string) (*models.User, error)
// Settings, profile and activity updates for a single user
UpdateUserSettings(ctx context.Context, userID int, settings *models.UserSettings) error
UpdateUserProfile(ctx context.Context, userID int, username, email, timezone string) error
UpdateUserPassword(ctx context.Context, userID int, newPassword string) error
UpdateLastActive(ctx context.Context, userID int) error
// Listing, deletion and other administrative operations
GetAllUsers(ctx context.Context) ([]models.User, error)
GetUsersPaginated(ctx context.Context, page, pageSize int, search, language, level, aiProvider, aiModel, aiEnabled, active string) ([]models.User, int, error)
DeleteUser(ctx context.Context, userID int) error
DeleteAllUsers(ctx context.Context) error
EnsureAdminUserExists(ctx context.Context, adminUsername, adminPassword string) error
ResetDatabase(ctx context.Context) error
ClearUserData(ctx context.Context) error
ClearUserDataForUser(ctx context.Context, userID int) error
// Per-user, per-provider API key storage
GetUserAPIKey(ctx context.Context, userID int, provider string) (string, error)
GetUserAPIKeyWithID(ctx context.Context, userID int, provider string) (string, *int, error)
SetUserAPIKey(ctx context.Context, userID int, provider, apiKey string) error
HasUserAPIKey(ctx context.Context, userID int, provider string) (bool, error)
// Role management methods
GetUserRoles(ctx context.Context, userID int) ([]models.Role, error)
GetAllRoles(ctx context.Context) ([]models.Role, error)
AssignRole(ctx context.Context, userID, roleID int) error
AssignRoleByName(ctx context.Context, userID int, roleName string) error
RemoveRole(ctx context.Context, userID, roleID int) error
HasRole(ctx context.Context, userID int, roleName string) (bool, error)
IsAdmin(ctx context.Context, userID int) (bool, error)
// Miscellaneous: raw database handle access and the word-of-day email preference
GetDB() *sql.DB
UpdateWordOfDayEmailEnabled(ctx context.Context, userID int, enabled bool) error
}
// UserService provides methods for user management.
// Instances are built with NewUserServiceWithLogger, which fills in all three
// fields.
type UserService struct {
// db is the database handle every query and transaction in this service runs on.
db *sql.DB
// cfg is the application configuration supplied at construction time.
cfg *config.Config
// logger receives the service's debug, warning and error log entries.
logger *observability.Logger
}
// Shared query constants to eliminate duplication.
// The column order in each list must stay in step with the Scan call that
// consumes it (scanUserFromRow and scanUserFromRowsNoPassword respectively),
// because rows are scanned positionally.
const (
// userSelectFields contains all user fields for SELECT queries,
// including password_hash.
userSelectFields = `id, username, email, timezone, password_hash, last_active, preferred_language, current_level, ai_provider, ai_model, ai_enabled, ai_api_key, word_of_day_email_enabled, created_at, updated_at`
// userSelectFieldsNoPassword contains user fields excluding password_hash for GetAllUsers;
// it is otherwise identical to userSelectFields.
userSelectFieldsNoPassword = `id, username, email, timezone, last_active, preferred_language, current_level, ai_provider, ai_model, ai_enabled, ai_api_key, word_of_day_email_enabled, created_at, updated_at`
)
// scanUserFromRow reads a single-row query result into a new models.User.
// The row must carry the columns of userSelectFields in that order. Any Scan
// error (including sql.ErrNoRows) is returned unchanged with a nil user.
func (s *UserService) scanUserFromRow(row *sql.Row) (result0 *models.User, err error) {
	var u models.User
	if err = row.Scan(
		&u.ID, &u.Username, &u.Email, &u.Timezone, &u.PasswordHash, &u.LastActive,
		&u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider,
		&u.AIModel, &u.AIEnabled, &u.AIAPIKey, &u.WordOfDayEmailEnabled, &u.CreatedAt, &u.UpdatedAt,
	); err != nil {
		return nil, err
	}
	return &u, nil
}
// scanUserFromRowsNoPassword reads the current row of a multi-row result into
// a new models.User. The row must carry the columns of
// userSelectFieldsNoPassword in that order, so PasswordHash is left unset.
// Any Scan error is returned unchanged with a nil user.
func (s *UserService) scanUserFromRowsNoPassword(rows *sql.Rows) (result0 *models.User, err error) {
	var u models.User
	if err = rows.Scan(
		&u.ID, &u.Username, &u.Email, &u.Timezone, &u.LastActive,
		&u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider,
		&u.AIModel, &u.AIEnabled, &u.AIAPIKey, &u.WordOfDayEmailEnabled, &u.CreatedAt, &u.UpdatedAt,
	); err != nil {
		return nil, err
	}
	return &u, nil
}
// getUserByQuery runs a single-row user query and scans the result.
// It returns (nil, nil) when the query matches no row, the scan error for any
// other failure, and otherwise the user after applyDefaultSettings has been
// called on it.
func (s *UserService) getUserByQuery(ctx context.Context, query string, args ...interface{}) (result0 *models.User, err error) {
	found, scanErr := s.scanUserFromRow(s.db.QueryRowContext(ctx, query, args...))
	switch {
	case scanErr == nil:
		// Defaults are applied best-effort; the call returns nothing to check.
		s.applyDefaultSettings(ctx, found)
		return found, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		// A missing user is reported as a nil user, not as an error.
		return nil, nil
	default:
		return nil, scanErr
	}
}
// NewUserServiceWithLogger builds a UserService that uses the given database
// handle, configuration and logger. The arguments are stored as-is.
func NewUserServiceWithLogger(db *sql.DB, cfg *config.Config, logger *observability.Logger) *UserService {
	svc := new(UserService)
	svc.db = db
	svc.cfg = cfg
	svc.logger = logger
	return svc
}
// CreateUser inserts a user with the given username, language and level, a
// timezone of "UTC", and the current time for its timestamps, then reloads and
// returns it via GetUserByID.
// Only used for testing purposes, should be moved to test utils if possible.
//
// It returns a wrapped ErrInvalidInput when username is empty or only
// whitespace, the raw database error when the insert or reload fails, and a
// wrapped ErrDatabaseQuery when the inserted user cannot be read back.
func (s *UserService) CreateUser(ctx context.Context, username, language, level string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "create_user", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)

	// Blank (empty or whitespace-only) usernames are rejected up front.
	if strings.TrimSpace(username) == "" {
		return nil, contextutils.WrapError(contextutils.ErrInvalidInput, "username cannot be empty")
	}

	// New users default to the UTC timezone.
	const insertUser = `INSERT INTO users (username, preferred_language, current_level, last_active, created_at, updated_at, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id`

	var (
		stamp = time.Now()
		newID int
	)
	if err = s.db.QueryRowContext(ctx, insertUser, username, language, level, stamp, stamp, stamp, "UTC").Scan(&newID); err != nil {
		return nil, err
	}

	created, loadErr := s.GetUserByID(ctx, newID)
	switch {
	case loadErr != nil:
		return nil, loadErr
	case created == nil:
		return nil, contextutils.WrapError(contextutils.ErrDatabaseQuery, "user was created but could not be retrieved from database")
	}
	return created, nil
}
// CreateUserWithEmailAndTimezone creates a new user with email and timezone.
//
// The user is inserted with AI disabled and the current time for its
// timestamps, reloaded via GetUserByID, and then given the default "user"
// role. A failure to assign the role is logged as a warning and does not fail
// the call.
//
// It returns a wrapped ErrInvalidInput when username is empty or only
// whitespace, ErrRecordExists when the insert hits a duplicate key, the raw
// database error for other insert or reload failures, and a wrapped
// ErrDatabaseQuery when the inserted user cannot be read back.
func (s *UserService) CreateUserWithEmailAndTimezone(ctx context.Context, username, email, timezone, language, level string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "create_user_with_email", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)

	// Validate username is not empty (TrimSpace also covers the "" case).
	if strings.TrimSpace(username) == "" {
		return nil, contextutils.WrapError(contextutils.ErrInvalidInput, "username cannot be empty")
	}

	query := `INSERT INTO users (username, email, timezone, preferred_language, current_level, ai_enabled, last_active, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id`
	now := time.Now()
	var id int
	err = s.db.QueryRowContext(ctx, query, username, email, timezone, language, level, false, now, now, now).Scan(&id)
	if err != nil {
		if isDuplicateKeyError(err) {
			return nil, contextutils.ErrRecordExists
		}
		return nil, err
	}

	var user *models.User
	user, err = s.GetUserByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if user == nil {
		return nil, contextutils.WrapError(contextutils.ErrDatabaseQuery, "user was created but could not be retrieved from database")
	}

	// Assign default "user" role to new users.
	// Log the error but don't fail the user creation: the role assignment can
	// be done manually by an admin if needed.
	if roleErr := s.AssignRoleByName(ctx, user.ID, "user"); roleErr != nil {
		s.logger.Warn(ctx, "Failed to assign default user role", map[string]interface{}{
			"user_id": user.ID,
			"error":   roleErr.Error(),
		})
	}
	return user, nil
}
// CreateUserWithPassword creates a new user with password authentication.
//
// The password is hashed with bcrypt at the default cost. The user is inserted
// with AI disabled, a timezone of "UTC" and the current time for its
// timestamps, reloaded via GetUserByID, and then given the default "user"
// role. A failure to assign the role is logged as a warning and does not fail
// the call.
//
// It returns a wrapped ErrInvalidInput when username is empty or only
// whitespace, the bcrypt error when hashing fails, ErrRecordExists when the
// insert hits a duplicate key, the raw database error for other insert or
// reload failures, and a wrapped ErrDatabaseQuery when the inserted user
// cannot be read back.
func (s *UserService) CreateUserWithPassword(ctx context.Context, username, password, language, level string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "create_user_with_password", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)

	// Validate username is not empty (TrimSpace also covers the "" case).
	if strings.TrimSpace(username) == "" {
		return nil, contextutils.WrapError(contextutils.ErrInvalidInput, "username cannot be empty")
	}

	// Hash the password using bcrypt; only the hash is stored.
	var hashedPassword []byte
	hashedPassword, err = bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}

	// default timezone to UTC for new users created with password
	query := `INSERT INTO users (username, password_hash, preferred_language, current_level, ai_enabled, last_active, created_at, updated_at, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id`
	now := time.Now()
	var id int
	err = s.db.QueryRowContext(ctx, query, username, string(hashedPassword), language, level, false, now, now, now, "UTC").Scan(&id)
	if err != nil {
		if isDuplicateKeyError(err) {
			return nil, contextutils.ErrRecordExists
		}
		return nil, err
	}

	var user *models.User
	user, err = s.GetUserByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if user == nil {
		return nil, contextutils.WrapError(contextutils.ErrDatabaseQuery, "user was created but could not be retrieved from database")
	}

	// Assign default "user" role to new users.
	// Log the error but don't fail the user creation: the role assignment can
	// be done manually by an admin if needed.
	if roleErr := s.AssignRoleByName(ctx, user.ID, "user"); roleErr != nil {
		s.logger.Warn(ctx, "Failed to assign default user role", map[string]interface{}{
			"user_id": user.ID,
			"error":   roleErr.Error(),
		})
	}
	return user, nil
}
// AuthenticateUser looks the user up by username and checks password against
// the stored bcrypt hash, returning the user on success.
//
// It returns the lookup error unchanged, "user not found" when no such user
// exists, "user has no password set" when the stored hash is NULL, and
// "invalid password" when the hash comparison fails.
func (s *UserService) AuthenticateUser(ctx context.Context, username, password string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "authenticate_user", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)

	account, lookupErr := s.GetUserByUsername(ctx, username)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case account == nil:
		return nil, errors.New("user not found")
	case !account.PasswordHash.Valid:
		// No hash stored, so there is nothing to compare against.
		return nil, errors.New("user has no password set")
	}

	// Compare the supplied password with the stored hash.
	storedHash := []byte(account.PasswordHash.String)
	if mismatch := bcrypt.CompareHashAndPassword(storedHash, []byte(password)); mismatch != nil {
		return nil, errors.New("invalid password")
	}
	return account, nil
}
// GetUserByID loads the user with the given ID together with its roles.
//
// It returns (nil, nil) when no such user exists and the database error when
// the lookup fails (the failure is also logged). If the roles cannot be
// loaded, a warning is logged and the user is returned with an empty role
// list instead of failing the call.
func (s *UserService) GetUserByID(ctx context.Context, id int) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_by_id", attribute.Int("user.id", id))
	defer observability.FinishSpan(span, &err)

	selectByID := fmt.Sprintf("SELECT %s FROM users WHERE id = $1", userSelectFields)
	found, lookupErr := s.getUserByQuery(ctx, selectByID, id)
	if lookupErr != nil {
		s.logger.Error(ctx, "Database error retrieving user", lookupErr, map[string]interface{}{"user_id": id})
		return nil, lookupErr
	}
	if found == nil {
		s.logger.Debug(ctx, "User not found in database", map[string]interface{}{"user_id": id})
		return nil, nil
	}

	// Attach roles; a failure here degrades to "no roles" rather than
	// failing the whole lookup.
	roles, rolesErr := s.GetUserRoles(ctx, id)
	if rolesErr != nil {
		s.logger.Warn(ctx, "Failed to load user roles", map[string]interface{}{"user_id": id, "error": rolesErr.Error()})
		roles = []models.Role{}
	}
	found.Roles = roles
	return found, nil
}
// GetUserByUsername loads the user whose username matches exactly.
// The result comes straight from getUserByQuery: (nil, nil) when no such user
// exists, or the database error when the lookup fails.
func (s *UserService) GetUserByUsername(ctx context.Context, username string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_by_username", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)

	selectByName := fmt.Sprintf("SELECT %s FROM users WHERE username = $1", userSelectFields)
	result0, err = s.getUserByQuery(ctx, selectByName, username)
	return result0, err
}
// UpdateUserSettings updates user settings including AI configuration.
//
// The language, level and AI fields of the users row are updated inside a
// transaction bound to ctx. When AI is disabled, the stored provider and model
// are cleared. When AI is enabled and both a provider and an API key are
// supplied, the key is upserted for that provider in the same transaction.
//
// It returns a wrapped ErrRecordNotFound when the user does not exist (either
// at the initial lookup or because the UPDATE affected no rows), a wrapped
// ErrInvalidInput when settings is nil, and a wrapped database error for any
// other failure, including a failed commit.
func (s *UserService) UpdateUserSettings(ctx context.Context, userID int, settings *models.UserSettings) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_user_settings", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	// Check if user exists before updating settings
	user, err := s.GetUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to check if user exists")
	}
	if user == nil {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	// Guard against a nil settings pointer, which would otherwise panic on
	// the first field access below.
	if settings == nil {
		return contextutils.WrapError(contextutils.ErrInvalidInput, "settings cannot be nil")
	}

	// Start a transaction to update both user settings and API key. BeginTx
	// ties the transaction to ctx so it is rolled back if ctx is cancelled.
	var tx *sql.Tx
	tx, err = s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for user settings update")
	}
	defer func() {
		// Rollback is a no-op returning ErrTxDone once the transaction has
		// been committed (or already rolled back); that case is expected.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Handle AI enabled logic: if AI is disabled, clear the provider and model.
	aiProvider := settings.AIProvider
	aiModel := settings.AIModel
	if !settings.AIEnabled {
		aiProvider = ""
		aiModel = ""
	}

	// Update user settings (excluding API key which is now stored separately)
	query := `UPDATE users SET preferred_language = $1, current_level = $2, ai_provider = $3, ai_model = $4, ai_enabled = $5, updated_at = $6 WHERE id = $7`
	var result sql.Result
	result, err = tx.ExecContext(ctx, query, settings.Language, settings.Level, aiProvider, aiModel, settings.AIEnabled, time.Now(), userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update user settings in transaction")
	}

	// Check if the user was actually updated
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if rowsAffected == 0 {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "user with ID %d not found", userID)
	}

	// If an API key is provided and AI is enabled, save it for the specific provider
	if settings.AIAPIKey != "" && settings.AIProvider != "" && settings.AIEnabled {
		err = s.setUserAPIKeyTx(ctx, tx, userID, settings.AIProvider, settings.AIAPIKey)
		if err != nil {
			return contextutils.WrapError(err, "failed to set user API key in transaction")
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit user settings transaction")
	}
	return nil
}
// UpdateWordOfDayEmailEnabled stores whether the user wants word-of-day
// emails and bumps the row's updated_at.
//
// It returns ErrRecordNotFound (unwrapped) when the user does not exist, and a
// wrapped error when the existence check or the update itself fails.
func (s *UserService) UpdateWordOfDayEmailEnabled(ctx context.Context, userID int, enabled bool) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_word_of_day_email_enabled",
		attribute.Int("user.id", userID),
		attribute.Bool("word_of_day_email_enabled", enabled),
	)
	defer observability.FinishSpan(span, &err)

	// The user must exist before the preference is written.
	existing, lookupErr := s.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		return contextutils.WrapError(lookupErr, "failed to check if user exists")
	case existing == nil:
		return contextutils.ErrRecordNotFound
	}

	const stmt = `UPDATE users SET word_of_day_email_enabled = $1, updated_at = NOW() WHERE id = $2`
	if _, execErr := s.db.ExecContext(ctx, stmt, enabled, userID); execErr != nil {
		return contextutils.WrapError(execErr, "failed to update word_of_day_email_enabled")
	}
	return nil
}
// GetUserAPIKey returns the API key stored for userID and provider.
//
// It returns a wrapped ErrRecordNotFound when either the user or the key for
// that provider does not exist, and a wrapped database error when the
// existence check or the key lookup fails.
func (s *UserService) GetUserAPIKey(ctx context.Context, userID int, provider string) (result0 string, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_api_key", attribute.Int("user.id", userID), attribute.String("user.provider", provider))
	defer observability.FinishSpan(span, &err)

	// The user must exist before its keys are consulted.
	owner, lookupErr := s.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		return "", contextutils.WrapError(lookupErr, "failed to check if user exists")
	case owner == nil:
		return "", contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	const selectKey = `SELECT api_key FROM user_api_keys WHERE user_id = $1 AND provider = $2`
	var stored string
	scanErr := s.db.QueryRowContext(ctx, selectKey, userID, provider).Scan(&stored)
	switch {
	case scanErr == nil:
		return stored, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return "", contextutils.WrapError(contextutils.ErrRecordNotFound, "API key for provider not found")
	default:
		return "", contextutils.WrapError(scanErr, "failed to get user API key")
	}
}
// GetUserAPIKeyWithID returns the API key stored for userID and provider
// together with a pointer to the key row's ID.
//
// On any failure the key is "" and the ID pointer is nil. It returns a wrapped
// ErrRecordNotFound when either the user or the key for that provider does
// not exist, and a wrapped database error when the existence check or the key
// lookup fails.
func (s *UserService) GetUserAPIKeyWithID(ctx context.Context, userID int, provider string) (apiKey string, apiKeyID *int, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_api_key_with_id", attribute.Int("user.id", userID), attribute.String("user.provider", provider))
	defer observability.FinishSpan(span, &err)

	// The user must exist before its keys are consulted.
	owner, lookupErr := s.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		return "", nil, contextutils.WrapError(lookupErr, "failed to check if user exists")
	case owner == nil:
		return "", nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	const selectKey = `SELECT id, api_key FROM user_api_keys WHERE user_id = $1 AND provider = $2`
	var (
		rowID  int
		stored string
	)
	scanErr := s.db.QueryRowContext(ctx, selectKey, userID, provider).Scan(&rowID, &stored)
	switch {
	case scanErr == nil:
		return stored, &rowID, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return "", nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "API key for provider not found")
	default:
		return "", nil, contextutils.WrapError(scanErr, "failed to get user API key with ID")
	}
}
// SetUserAPIKey stores apiKey for (userID, provider), inserting or replacing the
// row in user_api_keys inside a single transaction.
//
// The user is loaded first via GetUserByID; a nil user yields an error wrapping
// contextutils.ErrRecordNotFound. The transaction is started with BeginTx so it
// is bound to ctx. On any failure the transaction is rolled back and the wrapped
// error is returned.
func (s *UserService) SetUserAPIKey(ctx context.Context, userID int, provider, apiKey string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "set_user_api_key", attribute.Int("user.id", userID), attribute.String("user.provider", provider))
	defer observability.FinishSpan(span, &err)

	// Check if user exists before setting API key
	user, err := s.GetUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to check if user exists")
	}
	if user == nil {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for API key update")
	}
	defer func() {
		if err == nil {
			return
		}
		// sql.ErrTxDone means the transaction is already finished (failed commit
		// or context cancellation); there is nothing left to roll back.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	if err = s.setUserAPIKeyTx(ctx, tx, userID, provider, apiKey); err != nil {
		return contextutils.WrapError(err, "failed to set user API key in transaction")
	}
	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit API key transaction")
	}
	return nil
}
// setUserAPIKeyTx upserts the API key for (userID, provider) using the supplied
// transaction: a new row is inserted, or on a (user_id, provider) conflict the
// existing row's api_key and updated_at are overwritten. The caller owns the
// transaction and is responsible for commit/rollback.
//
// The exec error is passed straight to contextutils.WrapError; callers treat a
// nil return as success, so WrapError presumably returns nil for a nil error.
func (s *UserService) setUserAPIKeyTx(ctx context.Context, tx *sql.Tx, userID int, provider, apiKey string) error {
	const upsertSQL = `INSERT INTO user_api_keys (user_id, provider, api_key, updated_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, provider)
DO UPDATE SET api_key = $3, updated_at = $4`

	stamp := time.Now()
	_, execErr := tx.ExecContext(ctx, upsertSQL, userID, provider, apiKey, stamp)
	return contextutils.WrapError(execErr, "failed to execute API key transaction")
}
// HasUserAPIKey reports whether a non-empty API key is stored for
// (userID, provider). It delegates to GetUserAPIKey and interprets its errors:
//
//   - no error: true when the key is non-empty;
//   - ErrRecordNotFound whose message mentions the API key: (false, nil);
//   - any other ErrRecordNotFound (e.g. the user is missing): returned as-is;
//   - everything else: wrapped and returned.
func (s *UserService) HasUserAPIKey(ctx context.Context, userID int, provider string) (result0 bool, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "has_user_api_key", attribute.Int("user.id", userID), attribute.String("user.provider", provider))
	defer observability.FinishSpan(span, &err)

	key, lookupErr := s.GetUserAPIKey(ctx, userID, provider)
	switch {
	case lookupErr == nil:
		return key != "", nil
	case !errors.Is(lookupErr, contextutils.ErrRecordNotFound):
		return false, contextutils.WrapError(lookupErr, "failed to check if user has API key")
	case strings.Contains(lookupErr.Error(), "API key for provider not found"):
		// The key row is absent: that is a valid "no" answer, not a failure.
		// The message text is the only way to tell this apart from a missing user.
		return false, nil
	default:
		// Not-found refers to something other than the key; surface it unchanged.
		return false, lookupErr
	}
}
// UpdateLastActive stamps users.last_active with the current time for userID.
// It returns an error wrapping contextutils.ErrRecordNotFound when the UPDATE
// matches no row.
func (s *UserService) UpdateLastActive(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_last_active", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	const touchSQL = `UPDATE users SET last_active = $1 WHERE id = $2`
	outcome, err := s.db.ExecContext(ctx, touchSQL, time.Now(), userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update user last active timestamp")
	}

	// Zero affected rows means no user has this ID.
	touched, err := outcome.RowsAffected()
	switch {
	case err != nil:
		return contextutils.WrapError(err, "failed to get rows affected")
	case touched == 0:
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "user with ID %d not found", userID)
	}
	return nil
}
// GetAllUsers returns every row of the users table (selected with
// userSelectFieldsNoPassword) with each user's roles attached.
//
// A failure to load the roles of an individual user is logged and that user is
// returned with an empty role list; scan and iteration errors abort the call.
// A failure to close the result set is only logged.
func (s *UserService) GetAllUsers(ctx context.Context) (result0 []models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_all_users")
	defer observability.FinishSpan(span, &err)

	query := fmt.Sprintf("SELECT %s FROM users", userSelectFieldsNoPassword)
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query all users")
	}
	defer func() {
		// Use a local variable here: assigning rows.Close()'s result to the named
		// return value would replace (usually with nil) the error being returned.
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var users []models.User
	for rows.Next() {
		user, scanErr := s.scanUserFromRowsNoPassword(rows)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan user from rows")
		}

		// Load user roles; don't fail the entire request if roles can't be loaded.
		roles, rolesErr := s.GetUserRoles(ctx, user.ID)
		if rolesErr != nil {
			s.logger.Warn(ctx, "Failed to load user roles", map[string]interface{}{"user_id": user.ID, "error": rolesErr.Error()})
			roles = []models.Role{}
		}
		user.Roles = roles
		users = append(users, *user)
	}
	// rows.Next returning false can also mean the iteration failed part-way.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "error iterating users")
	}
	return users, nil
}
// GetUsersPaginated returns one page of users ordered by username, plus the
// total number of users matching the filters.
//
// Filters (each ignored when its argument is ""):
//   - search: case-insensitive substring match on username or email;
//   - language, level, aiProvider, aiModel: exact match on the matching column;
//   - aiEnabled: any non-empty value filters on ai_enabled, true only for "true";
//   - active: "true" keeps users whose last_active is within the last 7 days,
//     "false" keeps users older than that or with no last_active; any other
//     value applies no filter.
//
// page is 1-based; the offset is (page-1)*pageSize. A failure to load one
// user's roles is logged and that user is returned with an empty role list.
func (s *UserService) GetUsersPaginated(ctx context.Context, page, pageSize int, search, language, level, aiProvider, aiModel, aiEnabled, active string) (result0 []models.User, result1 int, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_users_paginated")
	defer observability.FinishSpan(span, &err)

	// Users count as active when last_active falls within this many days.
	const activeWindowDays = 7

	var (
		conditions []string
		args       []interface{}
	)
	// bind appends value to args and returns its positional placeholder. Deriving
	// the number from len(args) keeps placeholders and arguments in lock-step, so
	// a filter that binds nothing cannot shift later placeholders.
	bind := func(value interface{}) string {
		args = append(args, value)
		return fmt.Sprintf("$%d", len(args))
	}

	// Search filter: one argument referenced by both columns.
	if search != "" {
		p := bind("%" + search + "%")
		conditions = append(conditions, fmt.Sprintf("(username ILIKE %s OR email ILIKE %s)", p, p))
	}

	// Exact-match filters.
	equalityFilters := []struct{ column, value string }{
		{"preferred_language", language},
		{"current_level", level},
		{"ai_provider", aiProvider},
		{"ai_model", aiModel},
	}
	for _, f := range equalityFilters {
		if f.value != "" {
			conditions = append(conditions, f.column+" = "+bind(f.value))
		}
	}

	// AI Enabled filter
	if aiEnabled != "" {
		conditions = append(conditions, "ai_enabled = "+bind(aiEnabled == "true"))
	}

	// Active filter (based on last_active within the activity window).
	// Unrecognised values add neither a condition nor an argument.
	switch active {
	case "true":
		conditions = append(conditions, "last_active >= "+bind(time.Now().AddDate(0, 0, -activeWindowDays)))
	case "false":
		conditions = append(conditions, "(last_active < "+bind(time.Now().AddDate(0, 0, -activeWindowDays))+" OR last_active IS NULL)")
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}

	// Get total count (uses the filter arguments only).
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM users %s", whereClause)
	var total int
	if err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, contextutils.WrapError(err, "failed to count users")
	}

	// Get paginated results: LIMIT and OFFSET are bound after the filter arguments.
	offset := (page - 1) * pageSize
	limitPlaceholder := bind(pageSize)
	offsetPlaceholder := bind(offset)
	query := fmt.Sprintf("SELECT %s FROM users %s ORDER BY username LIMIT %s OFFSET %s",
		userSelectFieldsNoPassword, whereClause, limitPlaceholder, offsetPlaceholder)

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, contextutils.WrapError(err, "failed to query paginated users")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var users []models.User
	for rows.Next() {
		user, scanErr := s.scanUserFromRowsNoPassword(rows)
		if scanErr != nil {
			return nil, 0, contextutils.WrapError(scanErr, "failed to scan user from rows")
		}

		// Load user roles; don't fail the entire request if roles can't be loaded.
		roles, rolesErr := s.GetUserRoles(ctx, user.ID)
		if rolesErr != nil {
			s.logger.Warn(ctx, "Failed to load user roles", map[string]interface{}{"user_id": user.ID, "error": rolesErr.Error()})
			roles = []models.Role{}
		}
		user.Roles = roles
		users = append(users, *user)
	}
	// rows.Next returning false can also mean the iteration failed part-way.
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapError(err, "error iterating paginated users")
	}
	return users, total, nil
}
// GetUserByEmail loads the user whose email column equals email, selecting
// userSelectFields and delegating row handling to getUserByQuery. Whatever
// getUserByQuery returns (user and error) is passed through unchanged.
func (s *UserService) GetUserByEmail(ctx context.Context, email string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_by_email", attribute.String("user.email", email))
	defer observability.FinishSpan(span, &err)

	byEmailSQL := fmt.Sprintf("SELECT %s FROM users WHERE email = $1", userSelectFields)
	result0, err = s.getUserByQuery(ctx, byEmailSQL, email)
	return result0, err
}
// UpdateUserProfile overwrites the username, email and timezone of userID and
// refreshes updated_at. All three values are written as given. It returns an
// error wrapping contextutils.ErrRecordNotFound when the UPDATE matches no row.
func (s *UserService) UpdateUserProfile(ctx context.Context, userID int, username, email, timezone string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_user_profile", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	const profileSQL = `UPDATE users SET username = $1, email = $2, timezone = $3, updated_at = $4 WHERE id = $5`
	outcome, err := s.db.ExecContext(ctx, profileSQL, username, email, timezone, time.Now(), userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update user profile")
	}

	// Zero affected rows means no user has this ID.
	changed, err := outcome.RowsAffected()
	switch {
	case err != nil:
		return contextutils.WrapError(err, "failed to get rows affected")
	case changed == 0:
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "user with ID %d not found", userID)
	}
	return nil
}
// UpdateUserPassword replaces the stored password hash of userID with a bcrypt
// hash (default cost) of newPassword and refreshes updated_at.
//
// An empty newPassword is rejected before any database access. A missing user,
// whether detected by the initial lookup or by the UPDATE touching no rows,
// yields an error wrapping contextutils.ErrRecordNotFound.
func (s *UserService) UpdateUserPassword(ctx context.Context, userID int, newPassword string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_user_password", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	if len(newPassword) == 0 {
		return contextutils.ErrorWithContextf("password cannot be empty")
	}

	// The user is loaded up front; its username is needed for the success log.
	account, err := s.GetUserByID(ctx, userID)
	switch {
	case err != nil:
		return contextutils.WrapError(err, "failed to check if user exists")
	case account == nil:
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	digest, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return contextutils.WrapError(err, "failed to hash password")
	}

	const passwordSQL = `UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`
	outcome, err := s.db.ExecContext(ctx, passwordSQL, string(digest), time.Now(), userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update user password")
	}

	// The row may have disappeared between the lookup and the update.
	changed, err := outcome.RowsAffected()
	switch {
	case err != nil:
		return contextutils.WrapError(err, "failed to get rows affected")
	case changed == 0:
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	s.logger.Info(ctx, "Password updated successfully", map[string]interface{}{"user_id": userID, "username": account.Username})
	return nil
}
// DeleteUser removes the user with userID together with rows that reference it.
//
// The user is loaded first via GetUserByID; a nil user yields an error wrapping
// contextutils.ErrRecordNotFound. Dependent rows are then deleted table by
// table on a best-effort basis (failures are logged, not returned), followed by
// the users row itself. The statements are not wrapped in a transaction.
func (s *UserService) DeleteUser(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "delete_user", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	// Check if user exists before deleting
	user, err := s.GetUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to check if user exists")
	}
	if user == nil {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	// Best-effort cleanup of dependent rows for tables that may not have ON DELETE CASCADE in some environments
	// This keeps tests deterministic and avoids orphaned data
	// TODO: This is a hack to make the tests deterministic. We should use ON DELETE CASCADE instead.
	cleanupQueries := []string{
		`DELETE FROM question_reports WHERE reported_by_user_id = $1`,
		`DELETE FROM user_api_keys WHERE user_id = $1`,
		`DELETE FROM user_roles WHERE user_id = $1`,
		`DELETE FROM user_learning_preferences WHERE user_id = $1`,
		`DELETE FROM question_priority_scores WHERE user_id = $1`,
		`DELETE FROM user_question_metadata WHERE user_id = $1`,
		`DELETE FROM user_responses WHERE user_id = $1`,
		`DELETE FROM user_questions WHERE user_id = $1`,
	}
	for _, q := range cleanupQueries {
		// A separate error variable keeps cleanup failures out of the named result.
		if _, cleanupErr := s.db.ExecContext(ctx, q, userID); cleanupErr != nil {
			s.logger.Warn(ctx, "Non-fatal cleanup failure during user delete", map[string]interface{}{"error": cleanupErr.Error(), "query": q, "user_id": userID})
		}
	}

	// Delete the user
	result, err := s.db.ExecContext(ctx, `DELETE FROM users WHERE id = $1`, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete user")
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if rowsAffected == 0 {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	// The ID travels in the structured fields; the message carries no format verbs.
	s.logger.Info(ctx, "User deleted successfully", map[string]interface{}{"user_id": userID})
	return nil
}
// DeleteAllUsers empties the user_responses, performance_metrics and users
// tables (in that order) inside one transaction and restarts each table's
// "<table>_id_seq" sequence at 1.
//
// The transaction is bound to ctx. Any DELETE failure rolls the whole
// transaction back. A sequence-reset failure is logged and does not by itself
// fail the call.
func (s *UserService) DeleteAllUsers(ctx context.Context) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "delete_all_users")
	defer observability.FinishSpan(span, &err)

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for delete all users")
	}
	defer func() {
		if err == nil {
			return
		}
		// sql.ErrTxDone means the transaction is already finished (failed commit
		// or context cancellation); there is nothing left to roll back.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Table names are identifiers and cannot be bound as query parameters, so
	// they are checked against a whitelist before being formatted into SQL.
	validTables := map[string]bool{
		"user_responses":      true,
		"performance_metrics": true,
		"users":               true,
	}

	// Delete all data in the correct order (to respect foreign key constraints)
	tables := []string{
		"user_responses",
		"performance_metrics",
		"users",
	}

	for _, table := range tables {
		if !validTables[table] {
			return contextutils.ErrorWithContextf("invalid table name: %s", table)
		}

		if _, err = tx.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to delete from table %s", table)
		}

		// Reset sequence for PostgreSQL. IF EXISTS makes a missing sequence a
		// no-op instead of an error: a failed statement would abort the
		// surrounding PostgreSQL transaction and doom every later statement.
		sequenceQuery := fmt.Sprintf("ALTER SEQUENCE IF EXISTS %s_id_seq RESTART WITH 1", table)
		if _, seqErr := tx.ExecContext(ctx, sequenceQuery); seqErr != nil {
			s.logger.Warn(ctx, "Note: Could not reset sequence for table", map[string]interface{}{"table": table, "error": seqErr.Error()})
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit delete all users transaction")
	}
	return nil
}
// EnsureAdminUserExists makes sure a user named adminUsername exists and can
// log in with adminPassword.
//
//   - Both arguments must be non-empty.
//   - If the user exists and its stored bcrypt hash already matches
//     adminPassword, the password is left alone; otherwise (no hash, or a
//     mismatch) the hash is replaced.
//   - For an existing user, AI settings, profile defaults and the "admin" role
//     are then reconciled on a best-effort basis (failures are only logged).
//   - If the user does not exist it is created with default email, timezone,
//     language and level, given the password, default AI settings and the
//     "admin" role (the last two best-effort).
func (s *UserService) EnsureAdminUserExists(ctx context.Context, adminUsername, adminPassword string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "ensure_admin_user_exists", attribute.String("admin.username", adminUsername))
	defer observability.FinishSpan(span, &err)

	// Validate input parameters
	if adminUsername == "" {
		return contextutils.ErrorWithContextf("admin username cannot be empty")
	}
	if adminPassword == "" {
		return contextutils.ErrorWithContextf("admin password cannot be empty")
	}

	// Check if admin user already exists
	existingUser, err := s.GetUserByUsername(ctx, adminUsername)
	if err != nil {
		return contextutils.WrapError(err, "failed to check if admin user exists")
	}

	if existingUser != nil {
		// The stored hash is current only if it is present and verifies against
		// the configured admin password.
		passwordMatches := existingUser.PasswordHash.Valid &&
			bcrypt.CompareHashAndPassword([]byte(existingUser.PasswordHash.String), []byte(adminPassword)) == nil

		if !passwordMatches {
			hashedPassword, hashErr := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost)
			if hashErr != nil {
				return contextutils.WrapError(hashErr, "failed to hash admin password")
			}
			query := `UPDATE users SET password_hash = $1, updated_at = $2 WHERE username = $3`
			if _, err = s.db.ExecContext(ctx, query, string(hashedPassword), time.Now(), adminUsername); err != nil {
				return contextutils.WrapError(err, "failed to update admin user password")
			}
		}

		s.reconcileExistingAdmin(ctx, existingUser)

		if passwordMatches {
			s.logger.Info(ctx, "Admin user already exists with correct password", map[string]interface{}{"username": adminUsername})
		} else {
			s.logger.Info(ctx, "Updated password for admin user", map[string]interface{}{"username": adminUsername})
		}
		return nil
	}

	// Create new admin user with email and timezone
	user, err := s.CreateUserWithEmailAndTimezone(ctx, adminUsername, "admin@example.com", "America/New_York", "italian", "A1")
	if err != nil {
		return contextutils.WrapError(err, "failed to create admin user")
	}

	// Set password for the admin user
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost)
	if err != nil {
		return contextutils.WrapError(err, "failed to hash new admin password")
	}
	query := `UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`
	if _, err = s.db.ExecContext(ctx, query, string(hashedPassword), time.Now(), user.ID); err != nil {
		return contextutils.WrapError(err, "failed to set password for new admin user")
	}

	// Set up AI settings for the admin user (best-effort)
	if aiErr := s.ensureAdminAISettings(ctx, user.ID); aiErr != nil {
		s.logger.Warn(ctx, "Warning: Failed to set AI settings for new admin user", map[string]interface{}{"error": aiErr.Error()})
	}

	// Assign admin role to the admin user (best-effort)
	if roleErr := s.AssignRoleByName(ctx, user.ID, "admin"); roleErr != nil {
		s.logger.Warn(ctx, "Warning: Failed to assign admin role to new admin user", map[string]interface{}{"error": roleErr.Error()})
	}

	s.logger.Info(ctx, "Created admin user", map[string]interface{}{"username": adminUsername})
	return nil
}

// reconcileExistingAdmin brings an already-existing admin user up to date: AI
// settings, default email/timezone when either is unset, and the "admin" role.
// Every step is best-effort; failures are logged as warnings and never returned.
func (s *UserService) reconcileExistingAdmin(ctx context.Context, admin *models.User) {
	// Ensure AI settings are configured
	if err := s.ensureAdminAISettings(ctx, admin.ID); err != nil {
		s.logger.Warn(ctx, "Warning: Failed to set AI settings for existing admin user", map[string]interface{}{"error": err.Error()})
	}

	// Ensure admin user has email and timezone if not set
	if !admin.Email.Valid || !admin.Timezone.Valid {
		if err := s.ensureAdminProfile(ctx, admin.ID); err != nil {
			s.logger.Warn(ctx, "Warning: Failed to update admin profile", map[string]interface{}{"error": err.Error()})
		}
	}

	// Ensure admin user has admin role
	isAdmin, err := s.IsAdmin(ctx, admin.ID)
	if err != nil {
		s.logger.Warn(ctx, "Warning: Failed to check admin role for existing admin user", map[string]interface{}{"error": err.Error()})
		return
	}
	if isAdmin {
		return
	}
	if err := s.AssignRoleByName(ctx, admin.ID, "admin"); err != nil {
		s.logger.Warn(ctx, "Warning: Failed to assign admin role to existing admin user", map[string]interface{}{"error": err.Error()})
	}
}
// ensureAdminAISettings gives the user default AI settings when none are
// configured.
//
// If the user's AIProvider is valid and non-empty, nothing is changed.
// Otherwise ai_provider, ai_model and ai_api_key on the users row are set to
// the defaults below, and the same API key is stored in user_api_keys via
// SetUserAPIKey. A failure of that last step is logged only.
func (s *UserService) ensureAdminAISettings(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "ensure_admin_ai_settings", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	user, err := s.GetUserByID(ctx, userID)
	if err != nil {
		return err
	}
	if user == nil {
		return errors.New("admin user not found")
	}

	// If user already has AI provider configured, don't override their settings
	if user.AIProvider.Valid && user.AIProvider.String != "" {
		s.logger.Info(ctx, "User ID already has AI settings configured, preserving existing settings", map[string]interface{}{"user_id": userID, "provider": user.AIProvider.String})
		return nil
	}

	// Set default AI settings with a default API key
	settings := &models.UserSettings{
		AIProvider: "ollama",
		AIModel:    "llama4:latest",
		AIAPIKey:   "not_needed", // Default API key
	}

	// Only update AI settings, preserve other user settings
	query := `UPDATE users SET ai_provider = $1, ai_model = $2, ai_api_key = $3, updated_at = $4 WHERE id = $5`
	if _, err = s.db.ExecContext(ctx, query, settings.AIProvider, settings.AIModel, settings.AIAPIKey, time.Now(), userID); err != nil {
		return contextutils.WrapError(err, "failed to update user AI settings")
	}

	// Save the API key to the user_api_keys table (best-effort). The user ID is
	// carried in the structured fields; the message has no format verbs.
	if keyErr := s.SetUserAPIKey(ctx, userID, settings.AIProvider, settings.AIAPIKey); keyErr != nil {
		s.logger.Warn(ctx, "Warning: Failed to save API key for user", map[string]interface{}{"user_id": userID, "error": keyErr.Error()})
	}

	s.logger.Info(ctx, "Set default AI settings for user", map[string]interface{}{"user_id": userID, "provider": settings.AIProvider, "model": settings.AIModel})
	return nil
}
// ensureAdminProfile fills in the default email ("admin@example.com") and
// timezone ("America/New_York") for userID where they are missing.
//
// Only NULL columns are filled: COALESCE keeps a value that is already set, so
// a custom email is not overwritten just because the timezone is NULL (and vice
// versa). Rows with both columns set are not touched.
func (s *UserService) ensureAdminProfile(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "ensure_admin_profile", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	query := `UPDATE users SET email = COALESCE(email, $1), timezone = COALESCE(timezone, $2), updated_at = $3 WHERE id = $4 AND (email IS NULL OR timezone IS NULL)`
	if _, err = s.db.ExecContext(ctx, query, "admin@example.com", "America/New_York", time.Now(), userID); err != nil {
		return contextutils.WrapError(err, "failed to update admin profile")
	}

	s.logger.Info(ctx, "Updated admin user profile with default email and timezone", map[string]interface{}{"user_id": userID})
	return nil
}
// ResetDatabase empties the user_responses, performance_metrics, questions and
// users tables (in that order) inside one transaction and restarts each table's
// "<table>_id_seq" sequence at 1.
//
// The transaction is bound to ctx. Any DELETE failure rolls the whole
// transaction back. A sequence-reset failure is logged and does not by itself
// fail the call.
func (s *UserService) ResetDatabase(ctx context.Context) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "reset_database")
	defer observability.FinishSpan(span, &err)

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for database reset")
	}
	defer func() {
		// Rollback is a no-op returning sql.ErrTxDone once the transaction has
		// been committed (or already rolled back); that case is not worth a log.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Table names are identifiers and cannot be bound as query parameters, so
	// they are checked against a whitelist before being formatted into SQL.
	validTables := map[string]bool{
		"user_responses":      true,
		"performance_metrics": true,
		"questions":           true,
		"users":               true,
	}

	// Delete all data in the correct order (to respect foreign key constraints)
	tables := []string{
		"user_responses",
		"performance_metrics",
		"questions",
		"users",
	}

	for _, table := range tables {
		if !validTables[table] {
			return contextutils.ErrorWithContextf("invalid table name: %s", table)
		}

		if _, err = tx.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to delete from table %s during reset", table)
		}
		// The table name is carried in the structured fields, not the message.
		s.logger.Info(ctx, "Cleared table", map[string]interface{}{"table": table})

		// Reset sequence for PostgreSQL. IF EXISTS makes a missing sequence a
		// no-op instead of an error: a failed statement would abort the
		// surrounding PostgreSQL transaction and doom every later statement.
		sequenceQuery := fmt.Sprintf("ALTER SEQUENCE IF EXISTS %s_id_seq RESTART WITH 1", table)
		if _, seqErr := tx.ExecContext(ctx, sequenceQuery); seqErr != nil {
			s.logger.Warn(ctx, "Note: Could not reset sequence for table", map[string]interface{}{"table": table, "error": seqErr.Error()})
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit database reset transaction")
	}

	s.logger.Info(ctx, "Database reset completed successfully")
	return nil
}
// ClearUserData empties the user_responses, performance_metrics and questions
// tables (in that order) inside one transaction and restarts each table's
// "<table>_id_seq" sequence at 1. The users table is not touched.
//
// The transaction is bound to ctx. Any DELETE failure rolls the whole
// transaction back. A sequence-reset failure is logged and does not by itself
// fail the call.
func (s *UserService) ClearUserData(ctx context.Context) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "clear_user_data")
	defer observability.FinishSpan(span, &err)

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for clear user data")
	}
	defer func() {
		// Rollback is a no-op returning sql.ErrTxDone once the transaction has
		// been committed (or already rolled back); that case is not worth a log.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Table names are identifiers and cannot be bound as query parameters, so
	// they are checked against a whitelist before being formatted into SQL.
	validTables := map[string]bool{
		"user_responses":      true,
		"performance_metrics": true,
		"questions":           true,
	}

	// Delete user data but keep users (order matters due to foreign key constraints)
	tables := []string{
		"user_responses",
		"performance_metrics",
		"questions",
	}

	for _, table := range tables {
		if !validTables[table] {
			return contextutils.ErrorWithContextf("invalid table name: %s", table)
		}

		if _, err = tx.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to delete from table %s during clear user data", table)
		}
		// The table name is carried in the structured fields, not the message.
		s.logger.Info(ctx, "Cleared table", map[string]interface{}{"table": table})

		// Reset sequence for PostgreSQL. IF EXISTS makes a missing sequence a
		// no-op instead of an error: a failed statement would abort the
		// surrounding PostgreSQL transaction and doom every later statement.
		sequenceQuery := fmt.Sprintf("ALTER SEQUENCE IF EXISTS %s_id_seq RESTART WITH 1", table)
		if _, seqErr := tx.ExecContext(ctx, sequenceQuery); seqErr != nil {
			s.logger.Warn(ctx, "Note: Could not reset sequence for table", map[string]interface{}{"table": table, "error": seqErr.Error()})
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit clear user data transaction")
	}

	s.logger.Info(ctx, "User data cleared successfully (users preserved)")
	return nil
}
// ClearUserDataForUser removes activity data tied to userID while leaving the
// users row in place. Inside one transaction (bound to ctx) it runs, in order:
//
//  1. delete user_responses for every question linked to the user in user_questions;
//  2. delete the user's performance_metrics;
//  3. delete the user's user_questions rows;
//  4. delete questions no longer referenced by any user_questions row.
//
// The first failing step is logged, wrapped and returned, and the transaction
// is rolled back.
func (s *UserService) ClearUserDataForUser(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "clear_user_data_for_user", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		s.logger.Warn(ctx, "Failed to begin transaction", map[string]interface{}{"error": err.Error()})
		return contextutils.WrapError(err, "failed to begin transaction for clear user data for specific user")
	}
	defer func() {
		// Rollback is a no-op returning sql.ErrTxDone once the transaction has
		// been committed (or already rolled back); that case is not worth a log.
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// The delete steps, in execution order. Order matters: user_questions must
	// still be populated for step 1 and already pruned for step 4.
	steps := []struct {
		target  string        // what is being deleted, used in log messages
		query   string        // DELETE statement
		args    []interface{} // bound parameters
		wrapMsg string        // context added to a returned error
	}{
		{
			// NOTE(review): this matches on question_id only, so responses by
			// other users to the same questions appear to be deleted as well;
			// confirm this is intended when questions are shared between users.
			target:  "user_responses",
			query:   `DELETE FROM user_responses WHERE question_id IN (SELECT question_id FROM user_questions WHERE user_id = $1)`,
			args:    []interface{}{userID},
			wrapMsg: "failed to delete user responses for specific user",
		},
		{
			// performance_metrics has user_id, not question_id
			target:  "performance_metrics",
			query:   `DELETE FROM performance_metrics WHERE user_id = $1`,
			args:    []interface{}{userID},
			wrapMsg: "failed to delete performance metrics for specific user",
		},
		{
			target:  "user_questions",
			query:   `DELETE FROM user_questions WHERE user_id = $1`,
			args:    []interface{}{userID},
			wrapMsg: "failed to delete user questions for specific user",
		},
		{
			// Questions not assigned to any user after the step above.
			target:  "orphaned questions",
			query:   `DELETE FROM questions WHERE id NOT IN (SELECT question_id FROM user_questions)`,
			wrapMsg: "failed to delete orphaned questions",
		},
	}

	for _, step := range steps {
		result, execErr := tx.ExecContext(ctx, step.query, step.args...)
		if execErr != nil {
			s.logger.Warn(ctx, "Failed to delete "+step.target, map[string]interface{}{"error": execErr.Error()})
			return contextutils.WrapError(execErr, step.wrapMsg)
		}
		// The count is informational only, so a RowsAffected error is ignored.
		deleted, _ := result.RowsAffected()
		s.logger.Info(ctx, "Deleted "+step.target, map[string]interface{}{"count": deleted, "user_id": userID})
	}

	if err = tx.Commit(); err != nil {
		s.logger.Warn(ctx, "Failed to commit transaction", map[string]interface{}{"error": err.Error()})
		return contextutils.WrapError(err, "failed to commit clear user data for specific user transaction")
	}

	s.logger.Info(ctx, "User data cleared successfully for user (users preserved)", map[string]interface{}{"user_id": userID})
	return nil
}
// applyDefaultSettings fills empty settings on user in place, using s.cfg.
// It does nothing when user or s.cfg is nil.
//
//   - AIProvider/AIModel: when the provider string is empty and providers are
//     configured, the first provider's code and, if it has models, its first
//     model's code.
//   - CurrentLevel: when empty, the first level configured for the user's
//     preferred language, or for the first configured language if the user
//     has none.
//   - PreferredLanguage: when empty, "english". This runs last, so it does not
//     influence the level chosen above.
//
// Only the String fields are assigned; the Valid flags are left as they were.
func (s *UserService) applyDefaultSettings(ctx context.Context, user *models.User) {
	if user == nil || s.cfg == nil {
		return
	}
	_, span := observability.TraceUserFunction(ctx, "apply_default_settings", attribute.Int("user.id", user.ID))
	defer span.End()

	if user.AIProvider.String == "" && len(s.cfg.Providers) > 0 {
		// First configured provider (and its first model, if any) is the default.
		first := s.cfg.Providers[0]
		user.AIProvider.String = first.Code
		if len(first.Models) > 0 {
			user.AIModel.String = first.Models[0].Code
		}
	}

	if user.CurrentLevel.String == "" {
		// Pick the language whose level list supplies the default level.
		lang := user.PreferredLanguage.String
		if lang == "" {
			if available := s.cfg.GetLanguages(); len(available) > 0 {
				lang = available[0]
			}
		}
		if lang != "" {
			if levels := s.cfg.GetLevelsForLanguage(lang); len(levels) > 0 {
				user.CurrentLevel.String = levels[0]
			}
		}
	}

	if user.PreferredLanguage.String == "" {
		user.PreferredLanguage.String = "english"
	}
}
// GetUserRoles returns the roles linked to userID through user_roles, ordered
// by role name. A user with no roles yields a nil slice and a nil error.
// Query, scan and iteration errors are wrapped and returned; a failure to
// close the result set is only logged.
func (s *UserService) GetUserRoles(ctx context.Context, userID int) (result0 []models.Role, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_roles", attribute.Int("user.id", userID))
	defer func() {
		// Record the final error (if any) on the span before ending it.
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const rolesSQL = `
SELECT r.id, r.name, r.description, r.created_at, r.updated_at
FROM roles r
JOIN user_roles ur ON r.id = ur.role_id
WHERE ur.user_id = $1
ORDER BY r.name
`
	rows, err := s.db.QueryContext(ctx, rolesSQL, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get user roles")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var collected []models.Role
	for rows.Next() {
		var r models.Role
		if scanErr := rows.Scan(&r.ID, &r.Name, &r.Description, &r.CreatedAt, &r.UpdatedAt); scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan user role")
		}
		collected = append(collected, r)
	}
	// Distinguish a clean end of the result set from an iteration failure.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "error iterating user roles")
	}
	return collected, nil
}
// AssignRole links the role identified by roleID to the user identified by
// userID. It fails when either the user or the role does not exist. Assigning
// a role the user already holds is not an error: the insert uses
// ON CONFLICT DO NOTHING. Errors are recorded on the tracing span.
func (s *UserService) AssignRole(ctx context.Context, userID, roleID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "assign_role", attribute.Int("user.id", userID), attribute.Int("role.id", roleID))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	// The user must exist.
	user, err := s.GetUserByID(ctx, userID)
	switch {
	case err != nil:
		return contextutils.WrapError(err, "failed to get user for role assignment")
	case user == nil:
		return contextutils.ErrorWithContextf("user with ID %d not found", userID)
	}

	// The role must exist; its name is only needed for the log entry below.
	var roleName string
	err = s.db.QueryRowContext(ctx, "SELECT name FROM roles WHERE id = $1", roleID).Scan(&roleName)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return contextutils.ErrorWithContextf("role with ID %d not found", roleID)
	case err != nil:
		return contextutils.WrapError(err, "failed to check role existence")
	}

	// Duplicate assignments are absorbed by the conflict clause.
	const insertAssignment = `INSERT INTO user_roles (user_id, role_id, created_at) VALUES ($1, $2, $3) ON CONFLICT (user_id, role_id) DO NOTHING`
	if _, err = s.db.ExecContext(ctx, insertAssignment, userID, roleID, time.Now()); err != nil {
		return contextutils.WrapError(err, "failed to assign role to user")
	}

	s.logger.Info(ctx, "Role assigned successfully", map[string]interface{}{
		"user_id":   userID,
		"role_id":   roleID,
		"role_name": roleName,
	})
	return nil
}
// AssignRoleByName links the role called roleName to the user identified by
// userID. It fails when the user does not exist or no role has that name.
// Assigning a role the user already holds is not an error: the insert uses
// ON CONFLICT DO NOTHING. Errors are recorded on the tracing span.
func (s *UserService) AssignRoleByName(ctx context.Context, userID int, roleName string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "assign_role_by_name", attribute.Int("user.id", userID), attribute.String("role.name", roleName))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	// The user must exist.
	user, err := s.GetUserByID(ctx, userID)
	switch {
	case err != nil:
		return contextutils.WrapError(err, "failed to get user for role assignment")
	case user == nil:
		return contextutils.ErrorWithContextf("user with ID %d not found", userID)
	}

	// Resolve the role name to its ID.
	var roleID int
	err = s.db.QueryRowContext(ctx, "SELECT id FROM roles WHERE name = $1", roleName).Scan(&roleID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return contextutils.ErrorWithContextf("role with name '%s' not found", roleName)
	case err != nil:
		return contextutils.WrapError(err, "failed to get role ID by name")
	}

	// Duplicate assignments are absorbed by the conflict clause.
	const insertAssignment = `INSERT INTO user_roles (user_id, role_id, created_at) VALUES ($1, $2, $3) ON CONFLICT (user_id, role_id) DO NOTHING`
	if _, err = s.db.ExecContext(ctx, insertAssignment, userID, roleID, time.Now()); err != nil {
		return contextutils.WrapError(err, "failed to assign role to user")
	}

	s.logger.Info(ctx, "Role assigned successfully", map[string]interface{}{
		"user_id":   userID,
		"role_id":   roleID,
		"role_name": roleName,
	})
	return nil
}
// RemoveRole deletes the link between the user identified by userID and the
// role identified by roleID. It fails when the user or the role does not
// exist, and also when the user did not hold the role (the delete affected no
// rows). Errors are recorded on the tracing span.
func (s *UserService) RemoveRole(ctx context.Context, userID, roleID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "remove_role", attribute.Int("user.id", userID), attribute.Int("role.id", roleID))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	// The user must exist.
	user, err := s.GetUserByID(ctx, userID)
	switch {
	case err != nil:
		return contextutils.WrapError(err, "failed to get user for role removal")
	case user == nil:
		return contextutils.ErrorWithContextf("user with ID %d not found", userID)
	}

	// The role must exist; its name is only needed for the log entry below.
	var roleName string
	err = s.db.QueryRowContext(ctx, "SELECT name FROM roles WHERE id = $1", roleID).Scan(&roleName)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return contextutils.ErrorWithContextf("role with ID %d not found", roleID)
	case err != nil:
		return contextutils.WrapError(err, "failed to check role existence")
	}

	const deleteAssignment = `DELETE FROM user_roles WHERE user_id = $1 AND role_id = $2`
	outcome, err := s.db.ExecContext(ctx, deleteAssignment, userID, roleID)
	if err != nil {
		return contextutils.WrapError(err, "failed to remove role from user")
	}

	// Zero affected rows means there was no assignment to remove.
	removed, err := outcome.RowsAffected()
	switch {
	case err != nil:
		return contextutils.WrapError(err, "failed to get rows affected")
	case removed == 0:
		return contextutils.ErrorWithContextf("user %d does not have role %d", userID, roleID)
	}

	s.logger.Info(ctx, "Role removed successfully", map[string]interface{}{
		"user_id":   userID,
		"role_id":   roleID,
		"role_name": roleName,
	})
	return nil
}
// HasRole reports whether the user identified by userID holds a role whose
// name equals roleName. A query failure is wrapped, recorded on the span and
// returned together with false.
func (s *UserService) HasRole(ctx context.Context, userID int, roleName string) (result0 bool, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "has_role", attribute.Int("user.id", userID), attribute.String("role.name", roleName))
	defer func() {
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	const membership = `
SELECT COUNT(*) > 0
FROM user_roles ur
JOIN roles r ON ur.role_id = r.id
WHERE ur.user_id = $1 AND r.name = $2
`
	var found bool
	if err = s.db.QueryRowContext(ctx, membership, userID, roleName).Scan(&found); err != nil {
		return false, contextutils.WrapError(err, "failed to check if user has role")
	}
	return found, nil
}
// IsAdmin reports whether the user identified by userID holds the "admin"
// role. It delegates to HasRole inside its own tracing span.
func (s *UserService) IsAdmin(ctx context.Context, userID int) (result0 bool, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "is_admin", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	result0, err = s.HasRole(ctx, userID, "admin")
	return result0, err
}
// GetAllRoles returns every row of the roles table, ordered by name. An empty
// table yields a nil slice and no error. Query, scan and iteration failures
// are wrapped before being returned.
func (s *UserService) GetAllRoles(ctx context.Context) (result0 []models.Role, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_all_roles")
	defer observability.FinishSpan(span, &err)

	const allRoles = `
SELECT id, name, description, created_at, updated_at
FROM roles
ORDER BY name
`
	rows, err := s.db.QueryContext(ctx, allRoles)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get all roles")
	}
	defer func() {
		// A close failure is only logged; it does not change the result.
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var collected []models.Role
	for rows.Next() {
		var current models.Role
		if scanErr := rows.Scan(&current.ID, &current.Name, &current.Description, &current.CreatedAt, &current.UpdatedAt); scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan role")
		}
		collected = append(collected, current)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating roles")
	}
	return collected, nil
}
// GetDB exposes the *sql.DB handle held by this service, exactly as stored
// (no copy, no nil check).
func (s *UserService) GetDB() *sql.DB {
	return s.db
}
// isDuplicateKeyError reports whether err is, or wraps, a PostgreSQL
// unique-constraint violation (*pq.Error with code 23505).
//
// errors.As is used instead of a direct type assertion so the check still
// succeeds when the driver error has been wrapped with additional context
// before reaching this function. A nil err yields false.
func isDuplicateKeyError(err error) bool {
	var pqErr *pq.Error
	if !errors.As(err, &pqErr) {
		return false
	}
	// PostgreSQL error code 23505 is unique_violation.
	return pqErr.Code == "23505"
}
package services
import (
"context"
"math/rand"
"go.opentelemetry.io/otel/attribute"
"quizapp/internal/config"
"quizapp/internal/observability"
)
// VarietyService handles the selection of variety elements for question generation.
// It is constructed with NewVarietyServiceWithLogger.
type VarietyService struct {
// cfg is the application configuration; SelectVarietyElements reads its
// Variety section to obtain the candidate lists.
cfg *config.Config
// logger is the logger supplied to the constructor.
logger *observability.Logger
}
// VarietyElements holds the randomly selected variety elements for a question generation request.
// SelectVarietyElements only fills in the elements it picked for a given
// request; every other field is left as the empty string.
type VarietyElements struct {
TopicCategory string // picked from the configured topic categories
GrammarFocus string // picked from the level-specific grammar list, or the general one
VocabularyDomain string // picked from the configured vocabulary domains
Scenario string // picked from the configured scenarios
StyleModifier string // picked from the configured style modifiers
DifficultyModifier string // picked from the configured difficulty modifiers
TimeContext string // picked from the configured time contexts
}
// NewVarietyServiceWithLogger builds a VarietyService that reads its variety
// configuration from cfg and keeps logger for later use. Neither argument is
// validated here.
func NewVarietyServiceWithLogger(cfg *config.Config, logger *observability.Logger) *VarietyService {
	svc := new(VarietyService)
	svc.cfg = cfg
	svc.logger = logger
	return svc
}
// SelectVarietyElements randomly selects variety elements for question generation.
//
// Out of the element kinds that have candidates configured (topic category,
// grammar focus, vocabulary domain, scenario, style modifier, difficulty
// modifier, time context) it picks two at random, or three with a 30% chance
// when more than two kinds are available, and fills only those fields of the
// returned VarietyElements. The remaining fields stay empty.
//
// Selection rules per kind:
//   - Topic category: a configured topic that also appears in userWeakAreas is
//     preferred, then one that appears in highPriorityTopics, then a
//     gap-weighted pick, then a uniform random pick.
//   - Grammar focus: candidates are the level-specific list for level when it
//     is non-empty, otherwise the general grammar list; gap-weighted pick
//     first, uniform random otherwise.
//   - Vocabulary domain and scenario: gap-weighted pick first, uniform random
//     otherwise.
//   - Style modifier, difficulty modifier and time context: uniform random.
//
// gapAnalysis maps "<kind>_<option>" (for example "topic_category_travel") to
// a positive severity count; each option is weighted by the square of its
// count.
//
// The function never returns nil. When no configuration or no variety section
// is available it returns an empty VarietyElements and tags the span with
// variety.result = "no_config".
func (vs *VarietyService) SelectVarietyElements(ctx context.Context, level string, highPriorityTopics, userWeakAreas []string, gapAnalysis map[string]int) *VarietyElements {
	_, span := observability.TraceVarietyFunction(ctx, "select_variety_elements",
		attribute.String("variety.level", level),
		attribute.Int("variety.high_priority_topics_count", len(highPriorityTopics)),
		attribute.Int("variety.user_weak_areas_count", len(userWeakAreas)),
		attribute.Int("variety.gap_analysis_count", len(gapAnalysis)),
	)
	defer span.End()

	// Without a config (or a variety section) there is nothing to choose from.
	if vs.cfg == nil || vs.cfg.Variety == nil {
		span.SetAttributes(attribute.String("variety.result", "no_config"))
		return &VarietyElements{}
	}
	variety := vs.cfg.Variety
	elements := &VarietyElements{}

	// weightedPick draws one option with probability proportional to the
	// square of its gap count, or returns "" when no option has a positive
	// count. It walks cumulative weights instead of materialising one slice
	// entry per weight unit, so memory stays O(len(options)) regardless of how
	// large the counts are, while consuming the same single random draw.
	weightedPick := func(gapType string, options []string) string {
		if len(gapAnalysis) == 0 || len(options) == 0 {
			return ""
		}
		weights := make([]int, len(options))
		total := 0
		for i, option := range options {
			if count := gapAnalysis[gapType+"_"+option]; count > 0 {
				// Squaring intensifies the bias toward the most severe gaps.
				weights[i] = count * count
				total += weights[i]
			}
		}
		if total <= 0 {
			return ""
		}
		ticket := rand.Intn(total)
		for i, w := range weights {
			if ticket < w {
				return options[i]
			}
			ticket -= w
		}
		return ""
	}

	// uniform picks any option with equal probability; options must be non-empty.
	uniform := func(options []string) string {
		return options[rand.Intn(len(options))]
	}

	// gapOrUniform prefers a gap-weighted pick and falls back to a uniform one.
	gapOrUniform := func(gapType string, options []string) string {
		if choice := weightedPick(gapType, options); choice != "" {
			return choice
		}
		return uniform(options)
	}

	// matchingTopics lists the configured topics that also occur in preferred.
	// A topic is repeated once per occurrence in preferred, which keeps
	// duplicates in the caller's list acting as extra weight.
	matchingTopics := func(preferred []string) []string {
		var matching []string
		for _, topic := range variety.TopicCategories {
			for _, want := range preferred {
				if topic == want {
					matching = append(matching, topic)
				}
			}
		}
		return matching
	}

	// Each selector stores its pick on elements and returns it. Selectors are
	// registered in a fixed order; the shuffle below decides which ones run.
	type varietySelector struct {
		name     string
		selector func() string
	}
	var selectors []varietySelector
	register := func(name string, fn func() string) {
		selectors = append(selectors, varietySelector{name: name, selector: fn})
	}

	if len(variety.TopicCategories) > 0 {
		register("topic_category", func() string {
			// Weak areas outrank high-priority topics, which outrank gaps.
			for _, preferred := range [][]string{userWeakAreas, highPriorityTopics} {
				if matching := matchingTopics(preferred); len(matching) > 0 {
					elements.TopicCategory = uniform(matching)
					return elements.TopicCategory
				}
			}
			elements.TopicCategory = gapOrUniform("topic_category", variety.TopicCategories)
			return elements.TopicCategory
		})
	}

	grammarSelector := func(options []string) func() string {
		return func() string {
			elements.GrammarFocus = gapOrUniform("grammar_focus", options)
			return elements.GrammarFocus
		}
	}
	if grammarByLevel, exists := variety.GrammarFocusByLevel[level]; exists && len(grammarByLevel) > 0 {
		register("grammar_focus", grammarSelector(grammarByLevel))
	} else if len(variety.GrammarFocus) > 0 {
		register("grammar_focus", grammarSelector(variety.GrammarFocus))
	}

	if len(variety.VocabularyDomains) > 0 {
		register("vocabulary_domain", func() string {
			elements.VocabularyDomain = gapOrUniform("vocabulary_domain", variety.VocabularyDomains)
			return elements.VocabularyDomain
		})
	}

	if len(variety.Scenarios) > 0 {
		register("scenario", func() string {
			elements.Scenario = gapOrUniform("scenario", variety.Scenarios)
			return elements.Scenario
		})
	}

	if len(variety.StyleModifiers) > 0 {
		register("style_modifier", func() string {
			elements.StyleModifier = uniform(variety.StyleModifiers)
			return elements.StyleModifier
		})
	}

	if len(variety.DifficultyModifiers) > 0 {
		register("difficulty_modifier", func() string {
			elements.DifficultyModifier = uniform(variety.DifficultyModifiers)
			return elements.DifficultyModifier
		})
	}

	if len(variety.TimeContexts) > 0 {
		register("time_context", func() string {
			elements.TimeContext = uniform(variety.TimeContexts)
			return elements.TimeContext
		})
	}

	// Pick 2 elements, or 3 with a 30% chance when more than two kinds exist.
	numToSelect := 2
	if len(selectors) > 2 && rand.Float64() < 0.3 {
		numToSelect = 3
	}
	// Never claim more selections than there are selectors, so the reported
	// variety.elements_selected matches what was actually applied.
	numToSelect = min(numToSelect, len(selectors))

	// Shuffle, then run the first numToSelect selectors.
	rand.Shuffle(len(selectors), func(i, j int) {
		selectors[i], selectors[j] = selectors[j], selectors[i]
	})
	for _, chosen := range selectors[:numToSelect] {
		span.SetAttributes(attribute.String("variety."+chosen.name, chosen.selector()))
	}

	span.SetAttributes(
		attribute.String("variety.topic_category", elements.TopicCategory),
		attribute.String("variety.grammar_focus", elements.GrammarFocus),
		attribute.String("variety.vocabulary_domain", elements.VocabularyDomain),
		attribute.String("variety.scenario", elements.Scenario),
		attribute.String("variety.style_modifier", elements.StyleModifier),
		attribute.String("variety.difficulty_modifier", elements.DifficultyModifier),
		attribute.String("variety.time_context", elements.TimeContext),
		attribute.Int("variety.elements_selected", numToSelect),
	)
	span.SetAttributes(attribute.String("variety.result", "success"))
	return elements
}
// SelectMultipleVarietyElements selects count independent sets of variety
// elements for batch generation. Each set is produced by
// SelectVarietyElements with no weak areas, priority topics or gap analysis.
//
// A negative count is treated as zero and yields an empty, non-nil slice
// (previously it panicked while allocating the result).
func (vs *VarietyService) SelectMultipleVarietyElements(ctx context.Context, level string, count int) []*VarietyElements {
	ctx, span := observability.TraceVarietyFunction(ctx, "select_multiple_variety_elements",
		attribute.String("variety.level", level),
		attribute.Int("variety.count", count),
	)
	defer span.End()

	// make panics on a negative length; clamp instead of crashing the caller.
	if count < 0 {
		count = 0
	}

	elements := make([]*VarietyElements, count)
	for i := range elements {
		elements[i] = vs.SelectVarietyElements(ctx, level, nil, nil, nil)
	}

	span.SetAttributes(attribute.String("variety.result", "success"), attribute.Int("variety.elements_count", len(elements)))
	return elements
}
package services
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"math/rand"
"time"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// WordOfTheDayServiceInterface defines the interface for word of the day operations
type WordOfTheDayServiceInterface interface {
// GetWordOfTheDay returns the word assigned to the user for the given date.
// The WordOfTheDayService implementation generates and stores a new
// assignment when none exists yet.
GetWordOfTheDay(ctx context.Context, userID int, date time.Time) (*models.WordOfTheDayDisplay, error)
// SelectWordOfTheDay picks a word for the user and date, stores the
// assignment and returns it in display form.
SelectWordOfTheDay(ctx context.Context, userID int, date time.Time) (*models.WordOfTheDayDisplay, error)
// GetWordHistory returns the user's assignments whose date lies between
// startDate and endDate (both inclusive).
GetWordHistory(ctx context.Context, userID int, startDate, endDate time.Time) ([]*models.WordOfTheDayDisplay, error)
}
// WordOfTheDayService implements word of the day operations
// on top of a SQL database.
type WordOfTheDayService struct {
// db is the connection used for every query issued by this service.
db *sql.DB
// logger receives the service's informational, warning and error messages.
logger *observability.Logger
}
// ErrNoSuitableWord indicates there was no suitable word available for the user/date.
// SelectWordOfTheDay returns it when neither a vocabulary question nor a
// snippet could be selected; the tracing span is not marked as failed in
// that case.
var ErrNoSuitableWord = errors.New("no suitable word found")
// NewWordOfTheDayService builds a WordOfTheDayService that queries db and
// reports through logger. Neither argument is validated here.
func NewWordOfTheDayService(db *sql.DB, logger *observability.Logger) *WordOfTheDayService {
	svc := new(WordOfTheDayService)
	svc.db = db
	svc.logger = logger
	return svc
}
// GetWordOfTheDay retrieves the word of the day for a user and date.
// If no assignment exists yet, one is generated via SelectWordOfTheDay.
//
// date is reduced to its calendar day (year, month, day at 00:00 UTC) before
// it is used as the lookup key. Lookup failures other than "no rows" are
// recorded on the span and returned wrapped; errors from SelectWordOfTheDay
// (including ErrNoSuitableWord) are returned as-is.
func (s *WordOfTheDayService) GetWordOfTheDay(ctx context.Context, userID int, date time.Time) (*models.WordOfTheDayDisplay, error) {
	ctx, span := otel.Tracer("word-of-the-day-service").Start(ctx, "GetWordOfTheDay",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	defer observability.FinishSpan(span, nil)

	// Normalize date to just the date part (no time).
	date = time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, time.UTC)

	// Attach username to span (best-effort: a lookup failure is ignored on purpose).
	if u, _ := s.getUserByID(ctx, userID); u != nil {
		span.SetAttributes(attribute.String("user.username", u.Username))
	}

	// Try to get the existing word of the day. errors.Is keeps the "not found"
	// detection working even if the sentinel is ever wrapped on the way up.
	word, err := s.getWordOfTheDayFromDB(ctx, userID, date)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		return nil, contextutils.WrapError(err, "failed to get word of the day from database")
	}

	// If it exists, return it.
	if word != nil {
		span.SetAttributes(
			attribute.String("source_type", string(word.SourceType)),
			attribute.Int("source_id", word.SourceID),
		)
		return s.convertToDisplay(ctx, word)
	}

	// Otherwise generate one.
	s.logger.Info(ctx, "Word of the day not found, generating new one", map[string]interface{}{
		"user_id": userID,
		"date":    date.Format("2006-01-02"),
	})
	return s.SelectWordOfTheDay(ctx, userID, date)
}
// SelectWordOfTheDay chooses a word of the day for the user and date, stores
// the assignment and returns it in display form.
//
// The date is reduced to its calendar day (00:00 UTC). The user must exist
// and have a preferred language. The source is a vocabulary question with
// probability 0.7 and a user snippet otherwise; when the first source fails
// or yields nothing, the other one is tried. If both yield nothing,
// ErrNoSuitableWord is returned without marking the span as failed.
func (s *WordOfTheDayService) SelectWordOfTheDay(ctx context.Context, userID int, date time.Time) (*models.WordOfTheDayDisplay, error) {
	ctx, span := otel.Tracer("word-of-the-day-service").Start(ctx, "SelectWordOfTheDay",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	defer observability.FinishSpan(span, nil)

	// Keep only the calendar day.
	date = time.Date(date.Year(), date.Month(), date.Day(), 0, 0, 0, 0, time.UTC)

	// Load the user's language and level preferences.
	user, err := s.getUserByID(ctx, userID)
	switch {
	case err != nil:
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		return nil, contextutils.WrapError(err, "failed to get user")
	case user == nil:
		missing := contextutils.ErrorWithContextf("user not found: %d", userID)
		span.RecordError(missing, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, missing.Error())
		return nil, missing
	}

	language, level := user.PreferredLanguage.String, user.CurrentLevel.String
	if language == "" {
		noLanguage := contextutils.ErrorWithContextf("user missing language preference")
		span.RecordError(noLanguage, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, noLanguage.Error())
		return nil, noLanguage
	}

	span.SetAttributes(
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("user.username", user.Username),
	)

	// 70% vocabulary question first, 30% snippet first; each falls back to
	// the other source when it errors or comes back empty.
	var word *models.WordOfTheDay
	if rand.Float32() < 0.7 {
		if word, err = s.selectVocabularyQuestion(ctx, userID, language, level, date); err != nil || word == nil {
			s.logger.Warn(ctx, "Failed to select vocabulary question, trying snippet instead", map[string]interface{}{
				"error": err,
			})
			word, err = s.selectSnippet(ctx, userID, language, date)
		}
	} else {
		if word, err = s.selectSnippet(ctx, userID, language, date); err != nil || word == nil {
			s.logger.Warn(ctx, "Failed to select snippet, trying vocabulary question instead", map[string]interface{}{
				"error": err,
			})
			word, err = s.selectVocabularyQuestion(ctx, userID, language, level, date)
		}
	}

	switch {
	case err != nil:
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		return nil, contextutils.WrapError(err, "failed to select word of the day")
	case word == nil:
		// Having nothing to offer is a normal outcome, reported through a
		// sentinel without an error status on the span.
		span.SetAttributes(attribute.Bool("no_word_available", true))
		return nil, ErrNoSuitableWord
	}

	// Persist the assignment.
	if saveErr := s.saveWordOfTheDay(ctx, word); saveErr != nil {
		span.RecordError(saveErr, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, saveErr.Error())
		return nil, contextutils.WrapError(saveErr, "failed to save word of the day")
	}

	span.SetAttributes(
		attribute.String("source_type", string(word.SourceType)),
		attribute.Int("source_id", word.SourceID),
	)
	s.logger.Info(ctx, "Word of the day selected", map[string]interface{}{
		"user_id":     userID,
		"date":        date.Format("2006-01-02"),
		"source_type": word.SourceType,
		"source_id":   word.SourceID,
	})

	return s.convertToDisplay(ctx, word)
}
// GetWordHistory returns the user's word-of-the-day assignments whose date
// lies between startDate and endDate (both inclusive, compared by calendar
// day), newest first, in display form.
//
// Assignments that cannot be converted to display form are logged and left
// out of the result rather than failing the whole call. Query, scan and
// iteration failures are recorded on the span and returned wrapped.
func (s *WordOfTheDayService) GetWordHistory(ctx context.Context, userID int, startDate, endDate time.Time) ([]*models.WordOfTheDayDisplay, error) {
	ctx, span := otel.Tracer("word-of-the-day-service").Start(ctx, "GetWordHistory",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("start_date", startDate.Format("2006-01-02")),
			attribute.String("end_date", endDate.Format("2006-01-02")),
		),
	)
	defer observability.FinishSpan(span, nil)

	// Best-effort: tag the span with the username when the user can be loaded.
	if u, _ := s.getUserByID(ctx, userID); u != nil {
		span.SetAttributes(attribute.String("user.username", u.Username))
	}

	// Keep only the calendar day of both bounds.
	startDate = time.Date(startDate.Year(), startDate.Month(), startDate.Day(), 0, 0, 0, 0, time.UTC)
	endDate = time.Date(endDate.Year(), endDate.Month(), endDate.Day(), 0, 0, 0, 0, time.UTC)

	const historyQuery = `
SELECT id, user_id, assignment_date, source_type, source_id, created_at
FROM word_of_the_day
WHERE user_id = $1 AND assignment_date >= $2 AND assignment_date <= $3
ORDER BY assignment_date DESC
`
	rows, err := s.db.QueryContext(ctx, historyQuery, userID, startDate, endDate)
	if err != nil {
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		return nil, contextutils.WrapError(err, "failed to query word history")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			span.RecordError(closeErr, trace.WithStackTrace(true))
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// First read every row, then convert: conversion issues further queries.
	var assignments []*models.WordOfTheDay
	for rows.Next() {
		entry := new(models.WordOfTheDay)
		if scanErr := rows.Scan(&entry.ID, &entry.UserID, &entry.AssignmentDate, &entry.SourceType, &entry.SourceID, &entry.CreatedAt); scanErr != nil {
			span.RecordError(scanErr, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, scanErr.Error())
			return nil, contextutils.WrapError(scanErr, "failed to scan word row")
		}
		assignments = append(assignments, entry)
	}
	if iterErr := rows.Err(); iterErr != nil {
		span.RecordError(iterErr, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, iterErr.Error())
		return nil, contextutils.WrapError(iterErr, "error iterating word rows")
	}

	var displays []*models.WordOfTheDayDisplay
	for _, entry := range assignments {
		display, convErr := s.convertToDisplay(ctx, entry)
		if convErr != nil {
			// Skip entries whose source can no longer be rendered.
			s.logger.Error(ctx, "Failed to convert word to display", convErr, map[string]interface{}{
				"word_id":     entry.ID,
				"source_type": entry.SourceType,
				"source_id":   entry.SourceID,
			})
			continue
		}
		displays = append(displays, display)
	}

	span.SetAttributes(attribute.Int("count", len(displays)))
	return displays, nil
}
// selectVocabularyQuestion picks a random active vocabulary question in the
// given language (and level, unless level is empty) to serve as the user's
// word of the day for date.
//
// Questions already assigned to this user within the 60 days before date are
// avoided first; if that leaves nothing, the recency filter is dropped and
// any matching question may be reused. It returns (nil, nil) when no matching
// question exists at all. The returned value is not yet persisted.
func (s *WordOfTheDayService) selectVocabularyQuestion(ctx context.Context, userID int, language, level string, date time.Time) (*models.WordOfTheDay, error) {
	ctx, span := otel.Tracer("word-of-the-day-service").Start(ctx, "selectVocabularyQuestion",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("language", language),
			attribute.String("level", level),
		),
	)
	defer observability.FinishSpan(span, nil)

	// Best-effort: tag the span with the username when the user can be loaded.
	if u, _ := s.getUserByID(ctx, userID); u != nil {
		span.SetAttributes(attribute.String("user.username", u.Username))
	}

	// Query for vocabulary questions that haven't been used as word of the day recently.
	query := `
SELECT q.id
FROM questions q
WHERE q.type = 'vocabulary'
AND q.language = $1
AND q.status = 'active'
AND ($2 = '' OR q.level = $2)
AND NOT EXISTS (
SELECT 1 FROM word_of_the_day wotd
WHERE wotd.user_id = $3
AND wotd.source_type = 'vocabulary_question'
AND wotd.source_id = q.id
AND wotd.assignment_date > $4
)
ORDER BY RANDOM()
LIMIT 1
`
	// Don't reuse words from the last 60 days.
	cutoffDate := date.AddDate(0, 0, -60)

	var questionID int
	err := s.db.QueryRowContext(ctx, query, language, level, userID, cutoffDate).Scan(&questionID)
	if errors.Is(err, sql.ErrNoRows) {
		// Everything eligible was used recently: retry without the recency check.
		queryNoRecency := `
SELECT q.id
FROM questions q
WHERE q.type = 'vocabulary'
AND q.language = $1
AND q.status = 'active'
AND ($2 = '' OR q.level = $2)
ORDER BY RANDOM()
LIMIT 1
`
		err = s.db.QueryRowContext(ctx, queryNoRecency, language, level).Scan(&questionID)
	}
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil // No vocabulary questions available
	}
	if err != nil {
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		return nil, contextutils.WrapError(err, "failed to query vocabulary question")
	}

	return &models.WordOfTheDay{
		UserID:         userID,
		AssignmentDate: date,
		SourceType:     models.WordSourceVocabularyQuestion,
		SourceID:       questionID,
	}, nil
}
// selectSnippet picks one of the user's own snippets in the given source
// language to serve as the word of the day for date.
//
// Snippets already assigned to this user within the 60 days before date are
// avoided first, and snippets created within the 30 days before date are
// preferred over older ones; ties are broken randomly. If that leaves
// nothing, the recency filter is dropped and any matching snippet may be
// reused. It returns (nil, nil) when the user has no matching snippet at all.
// The returned value is not yet persisted.
func (s *WordOfTheDayService) selectSnippet(ctx context.Context, userID int, language string, date time.Time) (*models.WordOfTheDay, error) {
	ctx, span := otel.Tracer("word-of-the-day-service").Start(ctx, "selectSnippet",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("language", language),
		),
	)
	defer observability.FinishSpan(span, nil)

	// Best-effort: tag the span with the username when the user can be loaded.
	if u, _ := s.getUserByID(ctx, userID); u != nil {
		span.SetAttributes(attribute.String("user.username", u.Username))
	}

	// Query for user's snippets that haven't been used as word of the day recently.
	// Prefer more recent snippets (created in last 30 days).
	query := `
SELECT s.id
FROM snippets s
WHERE s.user_id = $1
AND s.source_language = $2
AND NOT EXISTS (
SELECT 1 FROM word_of_the_day wotd
WHERE wotd.user_id = $1
AND wotd.source_type = 'snippet'
AND wotd.source_id = s.id
AND wotd.assignment_date > $3
)
ORDER BY
CASE WHEN s.created_at > $4 THEN 0 ELSE 1 END,
RANDOM()
LIMIT 1
`
	// Don't reuse snippets from the last 60 days.
	cutoffDate := date.AddDate(0, 0, -60)
	// Prefer snippets from the last 30 days.
	recentCutoff := date.AddDate(0, 0, -30)

	var snippetID int
	err := s.db.QueryRowContext(ctx, query, userID, language, cutoffDate, recentCutoff).Scan(&snippetID)
	if errors.Is(err, sql.ErrNoRows) {
		// Everything eligible was used recently: retry without the recency check.
		queryNoRecency := `
SELECT s.id
FROM snippets s
WHERE s.user_id = $1
AND s.source_language = $2
ORDER BY RANDOM()
LIMIT 1
`
		err = s.db.QueryRowContext(ctx, queryNoRecency, userID, language).Scan(&snippetID)
	}
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil // No snippets available
	}
	if err != nil {
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		return nil, contextutils.WrapError(err, "failed to query snippet")
	}

	return &models.WordOfTheDay{
		UserID:         userID,
		AssignmentDate: date,
		SourceType:     models.WordSourceSnippet,
		SourceID:       snippetID,
	}, nil
}
// getWordOfTheDayFromDB loads the stored word-of-the-day assignment for the
// user and date.
//
// When no assignment exists it returns (nil, sql.ErrNoRows) with the sentinel
// left unwrapped, so callers can tell "not found" apart from real failures;
// any other error is returned wrapped.
func (s *WordOfTheDayService) getWordOfTheDayFromDB(ctx context.Context, userID int, date time.Time) (*models.WordOfTheDay, error) {
	query := `
SELECT id, user_id, assignment_date, source_type, source_id, created_at
FROM word_of_the_day
WHERE user_id = $1 AND assignment_date = $2
`
	var w models.WordOfTheDay
	err := s.db.QueryRowContext(ctx, query, userID, date).Scan(
		&w.ID, &w.UserID, &w.AssignmentDate, &w.SourceType, &w.SourceID, &w.CreatedAt,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// Hand back the bare sentinel regardless of how it reached us.
		return nil, sql.ErrNoRows
	case err != nil:
		return nil, contextutils.WrapError(err, "failed to query word of the day")
	}
	return &w, nil
}
// saveWordOfTheDay stores word as the assignment for its user and date and
// sets word.ID to the ID of the inserted row.
//
// The insert uses ON CONFLICT (user_id, assignment_date) DO NOTHING, so when
// an assignment for that user and date already exists (for example because a
// concurrent request stored one first) no row is inserted and RETURNING
// yields nothing. That case is not an error: the already-stored assignment is
// loaded and copied into *word, so the caller continues with the word that is
// actually persisted instead of the one it had just selected.
func (s *WordOfTheDayService) saveWordOfTheDay(ctx context.Context, word *models.WordOfTheDay) error {
	query := `
INSERT INTO word_of_the_day (user_id, assignment_date, source_type, source_id, created_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (user_id, assignment_date) DO NOTHING
RETURNING id
`
	err := s.db.QueryRowContext(ctx, query,
		word.UserID,
		word.AssignmentDate,
		word.SourceType,
		word.SourceID,
		time.Now(),
	).Scan(&word.ID)

	if errors.Is(err, sql.ErrNoRows) {
		// Conflict: nothing was inserted. Adopt the existing assignment.
		existing, lookupErr := s.getWordOfTheDayFromDB(ctx, word.UserID, word.AssignmentDate)
		if lookupErr != nil {
			return contextutils.WrapError(lookupErr, "failed to load existing word of the day after insert conflict")
		}
		*word = *existing
		return nil
	}
	if err != nil {
		return contextutils.WrapError(err, "failed to insert word of the day")
	}
	return nil
}
// convertToDisplay converts a stored WordOfTheDay assignment into its display
// form by loading the referenced source.
//
// For a vocabulary question, the word is taken from the content's "question"
// entry, the sentence from "sentence", and the translation from the
// "options" entry at index CorrectAnswer; language, level, explanation and
// topic category are copied from the question. Entries that are missing,
// JSON null, or (for the translation) outside the bounds of the options list
// leave the corresponding field empty.
//
// For a snippet, the original and translated text become word and
// translation; the optional context fills both Context and Sentence, and the
// optional difficulty level fills Level.
//
// Any other source type yields a display carrying only the date and source
// identifiers. A failure to load the source is recorded on the span and
// returned wrapped.
func (s *WordOfTheDayService) convertToDisplay(ctx context.Context, word *models.WordOfTheDay) (*models.WordOfTheDayDisplay, error) {
	ctx, span := otel.Tracer("word-of-the-day-service").Start(ctx, "convertToDisplay")
	defer observability.FinishSpan(span, nil)

	// Best-effort: tag the span with the user when it can be loaded.
	if u, _ := s.getUserByID(ctx, word.UserID); u != nil {
		span.SetAttributes(
			attribute.Int("user.id", u.ID),
			attribute.String("user.username", u.Username),
		)
	}

	display := &models.WordOfTheDayDisplay{
		Date:       word.AssignmentDate,
		SourceType: word.SourceType,
		SourceID:   word.SourceID,
	}

	switch word.SourceType {
	case models.WordSourceVocabularyQuestion:
		question, err := s.getQuestionByID(ctx, word.SourceID)
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
			return nil, contextutils.WrapError(err, "failed to get question")
		}

		// Extract word, translation, and sentence from question content.
		// Nil values (JSON null) are skipped so they do not render as "<nil>".
		content := question.Content
		if sentenceRaw, ok := content["sentence"]; ok && sentenceRaw != nil {
			display.Sentence = fmt.Sprintf("%v", sentenceRaw)
		}
		if questionRaw, ok := content["question"]; ok && questionRaw != nil {
			display.Word = fmt.Sprintf("%v", questionRaw)
		}
		if options, ok := content["options"].([]interface{}); ok {
			// Guard both bounds: a negative index would panic.
			if idx := question.CorrectAnswer; idx >= 0 && idx < len(options) && options[idx] != nil {
				display.Translation = fmt.Sprintf("%v", options[idx])
			}
		}

		display.Language = question.Language
		display.Level = question.Level
		display.Explanation = question.Explanation
		display.TopicCategory = question.TopicCategory

	case models.WordSourceSnippet:
		snippet, err := s.getSnippetByID(ctx, word.SourceID)
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
			return nil, contextutils.WrapError(err, "failed to get snippet")
		}

		display.Word = snippet.OriginalText
		display.Translation = snippet.TranslatedText
		display.Language = snippet.SourceLanguage
		if snippet.Context != nil {
			display.Context = *snippet.Context
			display.Sentence = *snippet.Context
		}
		if snippet.DifficultyLevel != nil {
			display.Level = *snippet.DifficultyLevel
		}
	}

	return display, nil
}
// getUserByID loads the row with the given ID from the users table.
// It returns (nil, nil) when no such row exists and a wrapped error when the
// query or scan fails for any other reason.
func (s *WordOfTheDayService) getUserByID(ctx context.Context, userID int) (*models.User, error) {
	const selectUser = `
SELECT id, username, email, preferred_language, current_level, timezone
FROM users
WHERE id = $1
`
	u := new(models.User)
	row := s.db.QueryRowContext(ctx, selectUser, userID)
	scanErr := row.Scan(
		&u.ID,
		&u.Username,
		&u.Email,
		&u.PreferredLanguage,
		&u.CurrentLevel,
		&u.Timezone,
	)

	switch {
	case scanErr == sql.ErrNoRows:
		// A missing user is not treated as an error by this helper.
		return nil, nil
	case scanErr != nil:
		return nil, contextutils.WrapError(scanErr, "failed to query user")
	}
	return u, nil
}
// getQuestionByID loads a single row from the questions table by ID and decodes
// its JSON content column into the question's Content map. Query or scan
// failures (a missing row included) and content decoding failures are returned
// as wrapped errors.
func (s *WordOfTheDayService) getQuestionByID(ctx context.Context, questionID int) (*models.Question, error) {
	const selectQuestion = `
SELECT id, type, language, level, difficulty_score, content, correct_answer,
explanation, created_at, status, topic_category, grammar_focus,
vocabulary_domain, scenario, style_modifier, difficulty_modifier, time_context
FROM questions
WHERE id = $1
`
	var (
		q          models.Question
		rawContent []byte
	)

	// Destinations are listed in the same order as the selected columns.
	dest := []interface{}{
		&q.ID,
		&q.Type,
		&q.Language,
		&q.Level,
		&q.DifficultyScore,
		&rawContent,
		&q.CorrectAnswer,
		&q.Explanation,
		&q.CreatedAt,
		&q.Status,
		&q.TopicCategory,
		&q.GrammarFocus,
		&q.VocabularyDomain,
		&q.Scenario,
		&q.StyleModifier,
		&q.DifficultyModifier,
		&q.TimeContext,
	}
	if err := s.db.QueryRowContext(ctx, selectQuestion, questionID).Scan(dest...); err != nil {
		return nil, contextutils.WrapError(err, "failed to query question")
	}

	// The content column is read as raw bytes and decoded separately.
	decoded := make(map[string]interface{})
	if err := json.Unmarshal(rawContent, &decoded); err != nil {
		return nil, contextutils.WrapError(err, "failed to parse question content")
	}
	q.Content = decoded

	return &q, nil
}
// getSnippetByID loads a single row from the snippets table by ID.
// Any query or scan failure, a missing row included, is returned wrapped.
func (s *WordOfTheDayService) getSnippetByID(ctx context.Context, snippetID int) (*models.Snippet, error) {
	const selectSnippet = `
SELECT id, user_id, original_text, translated_text, source_language,
target_language, question_id, section_id, story_id, context,
difficulty_level, created_at, updated_at
FROM snippets
WHERE id = $1
`
	result := new(models.Snippet)

	// Destinations are listed in the same order as the selected columns.
	dest := []interface{}{
		&result.ID,
		&result.UserID,
		&result.OriginalText,
		&result.TranslatedText,
		&result.SourceLanguage,
		&result.TargetLanguage,
		&result.QuestionID,
		&result.SectionID,
		&result.StoryID,
		&result.Context,
		&result.DifficultyLevel,
		&result.CreatedAt,
		&result.UpdatedAt,
	}
	if scanErr := s.db.QueryRowContext(ctx, selectSnippet, snippetID).Scan(dest...); scanErr != nil {
		return nil, contextutils.WrapError(scanErr, "failed to query snippet")
	}
	return result, nil
}
package services
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/attribute"
)
// ErrSettingNotFound is the sentinel error used when no row exists in the
// database for a requested setting key. GetSetting returns it wrapped with the
// key, and IsGlobalPaused detects it with errors.Is, so callers should match it
// with errors.Is rather than by direct comparison.
var ErrSettingNotFound = errors.New("setting not found")
// WorkerServiceInterface defines the interface for worker management operations.
// It groups key/value settings, per-instance worker status, pause/resume
// controls, question-priority analytics and notification reporting.
// WorkerService is the implementation provided in this file.
type WorkerServiceInterface interface {
// Settings management
// GetSetting returns the stored value for key; the implementation in this
// file reports a missing key with an error wrapping ErrSettingNotFound.
GetSetting(ctx context.Context, key string) (string, error)
// SetSetting creates or overwrites the value stored for key.
SetSetting(ctx context.Context, key, value string) error
// IsGlobalPaused / SetGlobalPause read and write the "global_pause" setting.
IsGlobalPaused(ctx context.Context) (bool, error)
SetGlobalPause(ctx context.Context, paused bool) error
// IsUserPaused / SetUserPause read and write the per-user "user_pause_<id>" setting.
IsUserPaused(ctx context.Context, userID int) (bool, error)
SetUserPause(ctx context.Context, userID int, paused bool) error
// Status management
// UpdateWorkerStatus upserts the full status row for a worker instance.
UpdateWorkerStatus(ctx context.Context, instance string, status *models.WorkerStatus) error
GetWorkerStatus(ctx context.Context, instance string) (*models.WorkerStatus, error)
GetAllWorkerStatuses(ctx context.Context) ([]models.WorkerStatus, error)
// UpdateHeartbeat records the current time as the instance's last heartbeat.
UpdateHeartbeat(ctx context.Context, instance string) error
// IsWorkerHealthy reports whether the instance has a recent heartbeat.
IsWorkerHealthy(ctx context.Context, instance string) (bool, error)
// Control operations
PauseWorker(ctx context.Context, instance string) error
ResumeWorker(ctx context.Context, instance string) error
// GetWorkerHealth returns a summary map of the global pause flag and per-instance health.
GetWorkerHealth(ctx context.Context) (map[string]interface{}, error)
// The three methods below report per-user question-priority analytics for
// one language, level and question type.
GetHighPriorityTopics(ctx context.Context, userID int, language, level, questionType string) ([]string, error)
GetGapAnalysis(ctx context.Context, userID int, language, level, questionType string) (map[string]int, error)
GetPriorityDistribution(ctx context.Context, userID int, language, level, questionType string) (map[string]int, error)
// Notification management
GetNotificationStats(ctx context.Context) (map[string]interface{}, error)
// NOTE(review): the paginated listings below return three values besides the
// error; their exact meaning is defined by the implementations — confirm there.
GetNotificationErrors(ctx context.Context, page, pageSize int, errorType, notificationType, resolved string) ([]map[string]interface{}, map[string]interface{}, map[string]interface{}, error)
GetUpcomingNotifications(ctx context.Context, page, pageSize int, notificationType, status, scheduledAfter, scheduledBefore string) ([]map[string]interface{}, map[string]interface{}, map[string]interface{}, error)
GetSentNotifications(ctx context.Context, page, pageSize int, notificationType, status, sentAfter, sentBefore string) ([]map[string]interface{}, map[string]interface{}, map[string]interface{}, error)
// Test methods for creating test data
CreateTestSentNotification(ctx context.Context, userID int, notificationType, subject, templateName, status, errorMessage string) error
}
// WorkerService implements worker management operations on top of a SQL
// database. Construct it with NewWorkerServiceWithLogger.
type WorkerService struct {
// db is the database handle every query in this service runs against.
db *sql.DB
// logger receives the debug, info and error records emitted by the methods.
logger *observability.Logger
}
// NewWorkerServiceWithLogger builds a WorkerService that runs its queries on db
// and reports through logger. Both arguments are stored as given.
func NewWorkerServiceWithLogger(db *sql.DB, logger *observability.Logger) *WorkerService {
	svc := new(WorkerService)
	svc.db = db
	svc.logger = logger
	return svc
}
// GetSetting retrieves the value stored for key in the worker_settings table.
//
// It returns an error for an empty or whitespace-only key, an error wrapping
// ErrSettingNotFound when no row exists for key, and a wrapped database error
// for any other failure. On error the returned string is empty.
func (s *WorkerService) GetSetting(ctx context.Context, key string) (result0 string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_setting", attribute.String("setting.key", key))
	// err is a named result so the deferred call sees the final error value.
	defer observability.FinishSpan(span, &err)

	// Trimming an empty string yields an empty string, so one check covers
	// both the empty and the whitespace-only key.
	if strings.TrimSpace(key) == "" {
		return "", contextutils.WrapErrorf(errors.New("invalid setting key"), "setting key cannot be empty")
	}

	var value string
	err = s.db.QueryRowContext(ctx, `
SELECT setting_value FROM worker_settings WHERE setting_key = $1
`, key).Scan(&value)
	if errors.Is(err, sql.ErrNoRows) {
		// A missing key is an expected condition: log at debug level and
		// surface the package sentinel so callers can match it.
		s.logger.Debug(ctx, "Setting not found", map[string]interface{}{"setting_key": key})
		return "", contextutils.WrapErrorf(ErrSettingNotFound, "%s", key)
	}
	if err != nil {
		s.logger.Error(ctx, "Failed to get setting", err, map[string]interface{}{"setting_key": key})
		return "", contextutils.WrapErrorf(err, "failed to get setting %s", key)
	}
	return value, nil
}
// SetSetting stores value under key in the worker_settings table, inserting a
// new row or overwriting the existing one, and stamps updated_at with NOW().
// An empty or whitespace-only key is rejected with an error.
func (s *WorkerService) SetSetting(ctx context.Context, key, value string) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "set_setting", attribute.String("setting.key", key))
	defer observability.FinishSpan(span, &err)

	// Validate key
	if strings.TrimSpace(key) == "" {
		return contextutils.WrapErrorf(errors.New("invalid setting key"), "setting key cannot be empty")
	}

	// Only one of the two log calls below runs, so they can share the fields.
	logFields := map[string]interface{}{"setting_key": key, "setting_value": value}

	if _, err = s.db.ExecContext(ctx, `
INSERT INTO worker_settings (setting_key, setting_value, updated_at)
VALUES ($1, $2, NOW())
ON CONFLICT (setting_key) DO UPDATE SET
setting_value = EXCLUDED.setting_value,
updated_at = EXCLUDED.updated_at
`, key, value); err != nil {
		s.logger.Error(ctx, "Failed to set setting", err, logFields)
		return contextutils.WrapErrorf(err, "failed to set setting %s", key)
	}

	s.logger.Debug(ctx, "Setting updated", logFields)
	return nil
}
// IsGlobalPaused reports whether the "global_pause" setting holds "true".
// When the setting does not exist yet it is created with the value "false" and
// false is returned; any other lookup failure is returned as an error.
func (s *WorkerService) IsGlobalPaused(ctx context.Context) (result0 bool, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "is_global_paused")
	defer observability.FinishSpan(span, &err)

	var raw string
	raw, err = s.GetSetting(ctx, "global_pause")

	switch {
	case err == nil:
		isPaused := raw == "true"
		s.logger.Debug(ctx, "Global pause status checked", map[string]interface{}{"global_paused": isPaused})
		return isPaused, nil

	case errors.Is(err, ErrSettingNotFound):
		// First use: persist the default (not paused) so later reads find it.
		if initErr := s.SetSetting(ctx, "global_pause", "false"); initErr != nil {
			s.logger.Error(ctx, "Failed to initialize global_pause setting", initErr, map[string]interface{}{})
			return false, contextutils.WrapError(initErr, "failed to initialize global_pause setting")
		}
		return false, nil

	default:
		s.logger.Error(ctx, "Failed to check global pause status", err, map[string]interface{}{})
		return false, err
	}
}
// SetGlobalPause writes "true" or "false" to the "global_pause" setting
// according to paused and logs the new state on success.
func (s *WorkerService) SetGlobalPause(ctx context.Context, paused bool) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "set_global_pause", attribute.Bool("paused", paused))
	defer observability.FinishSpan(span, &err)

	// The setting is stored as text.
	stored := "true"
	if !paused {
		stored = "false"
	}

	if err = s.SetSetting(ctx, "global_pause", stored); err != nil {
		return err
	}

	s.logger.Info(ctx, "Global pause state updated", map[string]interface{}{"global_paused": paused})
	return nil
}
// IsUserPaused reports whether the setting "user_pause_<userID>" holds "true".
//
// A missing setting means the user is not paused and is not an error; unlike
// IsGlobalPaused, no default row is written in that case. Any other database
// failure is returned wrapped.
func (s *WorkerService) IsUserPaused(ctx context.Context, userID int) (result0 bool, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "is_user_paused", observability.AttributeUserID(userID))
	// err is a named result so the deferred call sees the final error value.
	defer observability.FinishSpan(span, &err)

	key := fmt.Sprintf("user_pause_%d", userID)

	var value string
	err = s.db.QueryRowContext(ctx, `
SELECT setting_value FROM worker_settings WHERE setting_key = $1
`, key).Scan(&value)
	if errors.Is(err, sql.ErrNoRows) {
		// If setting doesn't exist, user is not paused (this is the default state).
		s.logger.Debug(ctx, "User pause setting not found, defaulting to not paused", map[string]interface{}{"user_id": userID})
		return false, nil
	}
	if err != nil {
		s.logger.Error(ctx, "Failed to check user pause status", err, map[string]interface{}{"user_id": userID})
		return false, contextutils.WrapErrorf(err, "failed to check user pause status for user %d", userID)
	}

	paused := value == "true"
	s.logger.Debug(ctx, "User pause status checked", map[string]interface{}{"user_id": userID, "user_paused": paused})
	return paused, nil
}
// SetUserPause writes "true" or "false" to the setting "user_pause_<userID>"
// according to paused and logs the new state on success.
func (s *WorkerService) SetUserPause(ctx context.Context, userID int, paused bool) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "set_user_pause", observability.AttributeUserID(userID), attribute.Bool("paused", paused))
	defer observability.FinishSpan(span, &err)

	// The setting is stored as text under a per-user key.
	stored := "true"
	if !paused {
		stored = "false"
	}
	settingKey := fmt.Sprintf("user_pause_%d", userID)

	if err = s.SetSetting(ctx, settingKey, stored); err != nil {
		return err
	}

	s.logger.Info(ctx, "User pause state updated", map[string]interface{}{"user_id": userID, "user_paused": paused})
	return nil
}
// UpdateWorkerStatus upserts the worker_status row for instance with the
// fields carried by status; updated_at is set to NOW() by the database.
//
// It returns an error when status is nil or when the statement fails.
func (s *WorkerService) UpdateWorkerStatus(ctx context.Context, instance string, status *models.WorkerStatus) (err error) {
	// status is dereferenced before the span is even started, so reject nil
	// up front instead of panicking.
	if status == nil {
		return contextutils.WrapErrorf(errors.New("invalid worker status"), "worker status cannot be nil")
	}

	// Flatten the nullable activity to a plain string for tracing and logging.
	activity := ""
	if status.CurrentActivity.Valid {
		activity = status.CurrentActivity.String
	}

	ctx, span := observability.TraceWorkerFunction(ctx, "update_worker_status",
		attribute.String("worker.instance", instance),
		attribute.Bool("worker.is_running", status.IsRunning),
		attribute.Bool("worker.is_paused", status.IsPaused),
		attribute.String("worker.activity", activity),
	)
	// err is a named result so the deferred call sees the final error value.
	defer observability.FinishSpan(span, &err)

	_, err = s.db.ExecContext(ctx, `
INSERT INTO worker_status (
worker_instance, is_running, is_paused, current_activity,
last_heartbeat, last_run_start, last_run_finish, last_run_error,
total_questions_generated, total_runs, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
ON CONFLICT (worker_instance) DO UPDATE SET
is_running = EXCLUDED.is_running,
is_paused = EXCLUDED.is_paused,
current_activity = EXCLUDED.current_activity,
last_heartbeat = EXCLUDED.last_heartbeat,
last_run_start = EXCLUDED.last_run_start,
last_run_finish = EXCLUDED.last_run_finish,
last_run_error = EXCLUDED.last_run_error,
total_questions_generated = EXCLUDED.total_questions_generated,
total_runs = EXCLUDED.total_runs,
updated_at = EXCLUDED.updated_at
`, instance, status.IsRunning, status.IsPaused, status.CurrentActivity,
		status.LastHeartbeat, status.LastRunStart, status.LastRunFinish,
		status.LastRunError, status.TotalQuestionsGenerated, status.TotalRuns)
	if err != nil {
		s.logger.Error(ctx, "Failed to update worker status", err, map[string]interface{}{
			"worker_instance": instance,
			"is_running":      status.IsRunning,
			"is_paused":       status.IsPaused,
			"activity":        activity,
		})
		err = contextutils.WrapErrorf(err, "failed to update worker status for instance %s", instance)
		return err
	}

	s.logger.Debug(ctx, "Worker status updated", map[string]interface{}{
		"worker_instance": instance,
		"is_running":      status.IsRunning,
		"is_paused":       status.IsPaused,
		"activity":        activity,
	})
	return nil
}
// GetWorkerStatus loads the worker_status row for instance.
//
// When no row exists the returned error wraps sql.ErrNoRows with a
// "worker status not found" message and is logged at debug level only; any
// other failure is logged as an error and returned wrapped. The status pointer
// is nil whenever an error is returned.
func (s *WorkerService) GetWorkerStatus(ctx context.Context, instance string) (result0 *models.WorkerStatus, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_worker_status", attribute.String("worker.instance", instance))
	// err is a named result so the deferred call sees the final error value.
	defer observability.FinishSpan(span, &err)

	var status models.WorkerStatus
	err = s.db.QueryRowContext(ctx, `
SELECT id, worker_instance, is_running, is_paused, current_activity,
last_heartbeat, last_run_start, last_run_finish, last_run_error,
total_questions_generated, total_runs, created_at, updated_at
FROM worker_status WHERE worker_instance = $1
`, instance).Scan(
		&status.ID, &status.WorkerInstance, &status.IsRunning, &status.IsPaused,
		&status.CurrentActivity, &status.LastHeartbeat, &status.LastRunStart,
		&status.LastRunFinish, &status.LastRunError, &status.TotalQuestionsGenerated,
		&status.TotalRuns, &status.CreatedAt, &status.UpdatedAt,
	)
	if errors.Is(err, sql.ErrNoRows) {
		s.logger.Debug(ctx, "Worker status not found", map[string]interface{}{"worker_instance": instance})
		return nil, contextutils.WrapErrorf(err, "worker status not found for instance %s", instance)
	}
	if err != nil {
		s.logger.Error(ctx, "Failed to get worker status", err, map[string]interface{}{"worker_instance": instance})
		return nil, contextutils.WrapErrorf(err, "failed to get worker status for instance %s", instance)
	}
	return &status, nil
}
// GetAllWorkerStatuses returns every row of worker_status ordered by
// worker_instance. The slice is nil when the table is empty. Query, scan and
// iteration failures are logged and returned wrapped.
func (s *WorkerService) GetAllWorkerStatuses(ctx context.Context) (result0 []models.WorkerStatus, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_all_worker_statuses")
	defer observability.FinishSpan(span, &err)

	var rows *sql.Rows
	if rows, err = s.db.QueryContext(ctx, `
SELECT id, worker_instance, is_running, is_paused, current_activity,
last_heartbeat, last_run_start, last_run_finish, last_run_error,
total_questions_generated, total_runs, created_at, updated_at
FROM worker_status ORDER BY worker_instance
`); err != nil {
		s.logger.Error(ctx, "Failed to get all worker statuses", err, map[string]interface{}{})
		return nil, contextutils.WrapError(err, "failed to get all worker statuses")
	}
	// A close failure is logged but does not change the result.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
		}
	}()

	var collected []models.WorkerStatus
	for rows.Next() {
		var row models.WorkerStatus
		err = rows.Scan(
			&row.ID, &row.WorkerInstance, &row.IsRunning, &row.IsPaused,
			&row.CurrentActivity, &row.LastHeartbeat, &row.LastRunStart,
			&row.LastRunFinish, &row.LastRunError, &row.TotalQuestionsGenerated,
			&row.TotalRuns, &row.CreatedAt, &row.UpdatedAt,
		)
		if err != nil {
			s.logger.Error(ctx, "Failed to scan worker status row", err, map[string]interface{}{})
			return nil, contextutils.WrapError(err, "failed to scan worker status row")
		}
		collected = append(collected, row)
	}

	// rows.Next returning false may hide an iteration error; check it.
	if iterErr := rows.Err(); iterErr != nil {
		s.logger.Error(ctx, "Error iterating worker status rows", iterErr, map[string]interface{}{})
		return nil, contextutils.WrapError(iterErr, "error iterating worker status rows")
	}

	s.logger.Debug(ctx, "Retrieved all worker statuses", map[string]interface{}{"count": len(collected)})
	return collected, nil
}
// UpdateHeartbeat sets last_heartbeat and updated_at to NOW() for instance,
// inserting the worker_status row first if it does not exist yet.
func (s *WorkerService) UpdateHeartbeat(ctx context.Context, instance string) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "update_heartbeat", attribute.String("worker.instance", instance))
	defer observability.FinishSpan(span, &err)

	// Only one of the two log calls below runs, so they can share the fields.
	logFields := map[string]interface{}{"worker_instance": instance}

	if _, err = s.db.ExecContext(ctx, `
INSERT INTO worker_status (worker_instance, last_heartbeat, updated_at)
VALUES ($1, NOW(), NOW())
ON CONFLICT (worker_instance) DO UPDATE SET
last_heartbeat = EXCLUDED.last_heartbeat,
updated_at = EXCLUDED.updated_at
`, instance); err != nil {
		s.logger.Error(ctx, "Failed to update heartbeat", err, logFields)
		return contextutils.WrapErrorf(err, "failed to update heartbeat for instance %s", instance)
	}

	s.logger.Debug(ctx, "Heartbeat updated", logFields)
	return nil
}
// IsWorkerHealthy reports whether instance has a heartbeat newer than five
// minutes.
//
// An instance with no worker_status row, or with a NULL last_heartbeat, is
// reported as unhealthy without an error. Other database failures are logged
// and returned wrapped.
func (s *WorkerService) IsWorkerHealthy(ctx context.Context, instance string) (result0 bool, err error) {
	// heartbeatTimeout is the maximum heartbeat age still considered healthy.
	const heartbeatTimeout = 5 * time.Minute

	ctx, span := observability.TraceWorkerFunction(ctx, "is_worker_healthy", attribute.String("worker.instance", instance))
	// err is a named result so the deferred call sees the final error value.
	defer observability.FinishSpan(span, &err)

	var lastHeartbeat sql.NullTime
	err = s.db.QueryRowContext(ctx, `
SELECT last_heartbeat FROM worker_status WHERE worker_instance = $1
`, instance).Scan(&lastHeartbeat)
	if errors.Is(err, sql.ErrNoRows) {
		s.logger.Debug(ctx, "Worker not found, considered unhealthy", map[string]interface{}{"worker_instance": instance})
		return false, nil
	}
	if err != nil {
		s.logger.Error(ctx, "Failed to check worker health", err, map[string]interface{}{"worker_instance": instance})
		return false, contextutils.WrapErrorf(err, "failed to check worker health for instance %s", instance)
	}

	if !lastHeartbeat.Valid {
		s.logger.Debug(ctx, "Worker has no heartbeat, considered unhealthy", map[string]interface{}{"worker_instance": instance})
		return false, nil
	}

	// Measure the heartbeat age once so the logged duration is exactly the
	// one the decision was based on.
	elapsed := time.Since(lastHeartbeat.Time)
	healthy := elapsed < heartbeatTimeout

	s.logger.Debug(ctx, "Worker health checked", map[string]interface{}{
		"worker_instance": instance,
		"healthy":         healthy,
		"last_heartbeat":  lastHeartbeat.Time,
		"time_since":      elapsed.String(),
	})
	return healthy, nil
}
// PauseWorker sets is_paused to true (and refreshes updated_at) on the
// worker_status row of instance. The number of affected rows is not inspected,
// so an unknown instance is not reported as an error.
func (s *WorkerService) PauseWorker(ctx context.Context, instance string) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "pause_worker", attribute.String("worker.instance", instance))
	defer observability.FinishSpan(span, &err)

	// Only one of the two log calls below runs, so they can share the fields.
	logFields := map[string]interface{}{"worker_instance": instance}

	if _, err = s.db.ExecContext(ctx, `
UPDATE worker_status SET is_paused = true, updated_at = NOW()
WHERE worker_instance = $1
`, instance); err != nil {
		s.logger.Error(ctx, "Failed to pause worker", err, logFields)
		return contextutils.WrapErrorf(err, "failed to pause worker instance %s", instance)
	}

	s.logger.Info(ctx, "Worker paused", logFields)
	return nil
}
// ResumeWorker sets is_paused to false (and refreshes updated_at) on the
// worker_status row of instance. The number of affected rows is not inspected,
// so an unknown instance is not reported as an error.
func (s *WorkerService) ResumeWorker(ctx context.Context, instance string) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "resume_worker", attribute.String("worker.instance", instance))
	defer observability.FinishSpan(span, &err)

	// Only one of the two log calls below runs, so they can share the fields.
	logFields := map[string]interface{}{"worker_instance": instance}

	if _, err = s.db.ExecContext(ctx, `
UPDATE worker_status SET is_paused = false, updated_at = NOW()
WHERE worker_instance = $1
`, instance); err != nil {
		s.logger.Error(ctx, "Failed to resume worker", err, logFields)
		return contextutils.WrapErrorf(err, "failed to resume worker instance %s", instance)
	}

	s.logger.Info(ctx, "Worker resumed", logFields)
	return nil
}
// GetWorkerHealth returns a summary of worker health with these keys:
//
//   - "global_paused":    the global pause flag (false if it cannot be read)
//   - "worker_instances": one map per worker whose health check succeeded
//   - "total_count":      the number of worker_status rows
//   - "healthy_count":    how many of the checked workers are healthy
//
// Only a failure to list the worker statuses is returned as an error. A
// failed global-pause lookup or per-worker health check is logged and
// tolerated; a worker whose check failed is omitted from "worker_instances"
// but still counted in "total_count".
func (s *WorkerService) GetWorkerHealth(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_worker_health")
	// err is a named result so the deferred call sees the final error value.
	defer observability.FinishSpan(span, &err)

	var statuses []models.WorkerStatus
	statuses, err = s.GetAllWorkerStatuses(ctx)
	if err != nil {
		return nil, err
	}

	var globalPaused bool
	globalPaused, err = s.IsGlobalPaused(ctx)
	if err != nil {
		s.logger.Error(ctx, "Failed to get global pause state", err, map[string]interface{}{})
		globalPaused = false // Default to false if we can't get the state
	}

	workerInstances := make([]map[string]interface{}, 0, len(statuses))
	healthyCount := 0
	totalCount := len(statuses)

	for _, status := range statuses {
		// NOTE(review): this issues one extra query per worker even though
		// the heartbeat was already loaded with the status row.
		healthy, err := s.IsWorkerHealthy(ctx, status.WorkerInstance)
		if err != nil {
			s.logger.Error(ctx, "Failed to check health for worker", err, map[string]interface{}{"worker_instance": status.WorkerInstance})
			continue
		}
		if healthy {
			healthyCount++
		}
		workerInstances = append(workerInstances, map[string]interface{}{
			"worker_instance":           status.WorkerInstance,
			"healthy":                   healthy,
			"is_running":                status.IsRunning,
			"is_paused":                 status.IsPaused,
			"last_heartbeat":            status.LastHeartbeat,
			"total_questions_generated": status.TotalQuestionsGenerated,
			"total_runs":                status.TotalRuns,
		})
	}

	// Build comprehensive health summary.
	health := map[string]interface{}{
		"global_paused":    globalPaused,
		"worker_instances": workerInstances,
		"total_count":      totalCount,
		"healthy_count":    healthyCount,
	}

	// Log the number of reported workers, not the number of keys in the
	// summary map (which is always four).
	s.logger.Debug(ctx, "Worker health retrieved", map[string]interface{}{"worker_count": len(workerInstances)})
	return health, nil
}
// GetHighPriorityTopics returns up to five topic categories whose average
// priority score for the user is at least 7.0, highest average first, limited
// to questions of the given language, level and type that have a non-empty
// topic category. The slice is nil when nothing qualifies.
func (s *WorkerService) GetHighPriorityTopics(ctx context.Context, userID int, language, level, questionType string) (result0 []string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_high_priority_topics",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("question.type", questionType),
	)
	defer observability.FinishSpan(span, &err)

	const topicsQuery = `
SELECT q.topic_category, AVG(qps.priority_score) as avg_score
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
AND q.topic_category IS NOT NULL
AND q.topic_category != ''
GROUP BY q.topic_category
HAVING AVG(qps.priority_score) >= 7.0
ORDER BY avg_score DESC
LIMIT 5
`
	var rows *sql.Rows
	rows, err = s.db.QueryContext(ctx, topicsQuery, userID, language, level, questionType)
	if err != nil {
		s.logger.Error(ctx, "Failed to get high priority topics", err, map[string]interface{}{
			"user_id": userID, "language": language, "level": level, "question_type": questionType,
		})
		return nil, contextutils.WrapError(err, "failed to get high priority topics")
	}
	// A close failure is logged but does not change the result.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
		}
	}()

	var names []string
	for rows.Next() {
		var (
			name    string
			average float64 // selected for ordering; only scanned, not returned
		)
		if scanErr := rows.Scan(&name, &average); scanErr != nil {
			s.logger.Error(ctx, "Failed to scan high priority topics row", scanErr, map[string]interface{}{})
			return nil, contextutils.WrapError(scanErr, "failed to scan high priority topics row")
		}
		names = append(names, name)
	}

	if iterErr := rows.Err(); iterErr != nil {
		s.logger.Error(ctx, "Error iterating high priority topics rows", iterErr, map[string]interface{}{})
		return nil, contextutils.WrapError(iterErr, "error iterating high priority topics rows")
	}

	s.logger.Debug(ctx, "Retrieved high priority topics", map[string]interface{}{"user_id": userID, "count": len(names)})
	return names, nil
}
// GetGapAnalysis reports the areas in which the user answers poorly for the
// given language, level and question type.
//
// The query groups the user's questions by topic, grammar focus, vocabulary
// domain and scenario, keeps the groups whose accuracy is below 60% (or NULL),
// and emits one row per dimension. The result maps "<gap_type>_<area>" (for
// example "topic_travel") to the number of questions in that group.
func (s *WorkerService) GetGapAnalysis(ctx context.Context, userID int, language, level, questionType string) (result0 map[string]int, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_gap_analysis",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("question.type", questionType),
	)
	defer observability.FinishSpan(span, &err)

	// Finds areas where the user has low accuracy, across topics and varieties.
	const gapQuery = `
WITH user_performance AS (
SELECT
q.topic_category,
q.grammar_focus,
q.vocabulary_domain,
q.scenario,
COUNT(*) as total_questions,
COUNT(CASE WHEN ur.is_correct = true THEN 1 END) as correct_answers,
ROUND(
COUNT(CASE WHEN ur.is_correct = true THEN 1 END)::decimal / COUNT(*)::decimal * 100, 2
) as accuracy_percentage
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN user_responses ur ON q.id = ur.question_id AND ur.user_id = $1
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
GROUP BY q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario
)
SELECT
COALESCE(topic_category, 'unknown') as area,
'topic' as gap_type,
total_questions,
accuracy_percentage
FROM user_performance
WHERE accuracy_percentage < 60 OR accuracy_percentage IS NULL
UNION ALL
SELECT
COALESCE(grammar_focus, 'unknown') as area,
'grammar' as gap_type,
total_questions,
accuracy_percentage
FROM user_performance
WHERE (accuracy_percentage < 60 OR accuracy_percentage IS NULL) AND grammar_focus IS NOT NULL
UNION ALL
SELECT
COALESCE(vocabulary_domain, 'unknown') as area,
'vocabulary' as gap_type,
total_questions,
accuracy_percentage
FROM user_performance
WHERE (accuracy_percentage < 60 OR accuracy_percentage IS NULL) AND vocabulary_domain IS NOT NULL
UNION ALL
SELECT
COALESCE(scenario, 'unknown') as area,
'scenario' as gap_type,
total_questions,
accuracy_percentage
FROM user_performance
WHERE (accuracy_percentage < 60 OR accuracy_percentage IS NULL) AND scenario IS NOT NULL
ORDER BY accuracy_percentage ASC, total_questions DESC
`
	var rows *sql.Rows
	rows, err = s.db.QueryContext(ctx, gapQuery, userID, language, level, questionType)
	if err != nil {
		s.logger.Error(ctx, "Failed to get gap analysis", err, map[string]interface{}{
			"user_id": userID, "language": language, "level": level, "question_type": questionType,
		})
		return nil, contextutils.WrapError(err, "failed to get gap analysis")
	}
	// A close failure is logged but does not change the result.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
		}
	}()

	result := make(map[string]int)
	for rows.Next() {
		var (
			area     string
			gapType  string
			total    int
			accuracy sql.NullFloat64 // selected for ordering; only scanned, not returned
		)
		if scanErr := rows.Scan(&area, &gapType, &total, &accuracy); scanErr != nil {
			s.logger.Error(ctx, "Failed to scan gap analysis row", scanErr, map[string]interface{}{})
			return nil, contextutils.WrapError(scanErr, "failed to scan gap analysis row")
		}

		// The question count serves as the severity of the gap; the key
		// carries the gap type so areas from different dimensions stay apart.
		// NOTE(review): several grouped rows can share one key (e.g. the same
		// topic with different grammar focus); the last row read overwrites
		// earlier ones — confirm whether summing was intended.
		result[fmt.Sprintf("%s_%s", gapType, area)] = total
	}

	if iterErr := rows.Err(); iterErr != nil {
		s.logger.Error(ctx, "Error iterating gap analysis rows", iterErr, map[string]interface{}{})
		return nil, contextutils.WrapError(iterErr, "error iterating gap analysis rows")
	}

	s.logger.Debug(ctx, "Retrieved gap analysis", map[string]interface{}{"user_id": userID, "count": len(result)})
	return result, nil
}
// GetPriorityDistribution returns, per topic category, how many of the user's
// questions of the given language, level and type have a priority score row.
//
// Questions whose topic_category is NULL are reported under the key "unknown",
// the same placeholder GetGapAnalysis uses. The map is empty, never nil, when
// no rows match. Query, scan and iteration failures are logged and returned
// wrapped.
func (s *WorkerService) GetPriorityDistribution(ctx context.Context, userID int, language, level, questionType string) (result0 map[string]int, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_priority_distribution",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("question.type", questionType),
	)
	// err is a named result so the deferred call sees the final error value.
	defer observability.FinishSpan(span, &err)

	// Query to get priority score distribution by topic.
	query := `
SELECT q.topic_category, COUNT(*) as question_count
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
GROUP BY q.topic_category
`
	rows, err := s.db.QueryContext(ctx, query, userID, language, level, questionType)
	if err != nil {
		s.logger.Error(ctx, "Failed to get priority distribution", err, map[string]interface{}{
			"user_id": userID, "language": language, "level": level, "question_type": questionType,
		})
		return nil, contextutils.WrapError(err, "failed to get priority distribution")
	}
	// A close failure is logged but does not change the result.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
		}
	}()

	distribution := make(map[string]int)
	for rows.Next() {
		// The query does not filter out NULL topic categories, and scanning
		// NULL into a plain string fails, so scan into a nullable type.
		var topic sql.NullString
		var count int
		if scanErr := rows.Scan(&topic, &count); scanErr != nil {
			s.logger.Error(ctx, "Failed to scan priority distribution row", scanErr, map[string]interface{}{})
			return nil, contextutils.WrapError(scanErr, "failed to scan priority distribution row")
		}

		name := "unknown"
		if topic.Valid {
			name = topic.String
		}
		distribution[name] = count
	}

	if iterErr := rows.Err(); iterErr != nil {
		s.logger.Error(ctx, "Error iterating priority distribution rows", iterErr, map[string]interface{}{})
		return nil, contextutils.WrapError(iterErr, "error iterating priority distribution rows")
	}

	s.logger.Debug(ctx, "Retrieved priority distribution", map[string]interface{}{"user_id": userID, "count": len(distribution)})
	return distribution, nil
}
// GetNotificationStats returns comprehensive notification statistics
func (s *WorkerService) GetNotificationStats(ctx context.Context) (result0 map[string]interface{}, err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "get_notification_stats")
defer observability.FinishSpan(span, &err)
// Get total notifications sent
var totalSent int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications WHERE status = 'sent'
`).Scan(&totalSent)
if err != nil {
s.logger.Error(ctx, "Failed to get total notifications sent", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get total notifications sent")
}
// Get total notifications failed
var totalFailed int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications WHERE status = 'failed'
`).Scan(&totalFailed)
if err != nil {
s.logger.Error(ctx, "Failed to get total notifications failed", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get total notifications failed")
}
// Calculate success rate
var successRate float64
if totalSent+totalFailed > 0 {
successRate = float64(totalSent) / float64(totalSent+totalFailed)
}
// Get users with notifications enabled
var usersWithNotifications int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(DISTINCT user_id) FROM user_learning_preferences WHERE daily_reminder_enabled = true
`).Scan(&usersWithNotifications)
if err != nil {
s.logger.Error(ctx, "Failed to get users with notifications enabled", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get users with notifications enabled")
}
// Get total users
var totalUsers int
err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&totalUsers)
if err != nil {
s.logger.Error(ctx, "Failed to get total users", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get total users")
}
// Get notifications sent today
var sentToday int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications
WHERE status = 'sent' AND DATE(sent_at) = CURRENT_DATE
`).Scan(&sentToday)
if err != nil {
s.logger.Error(ctx, "Failed to get notifications sent today", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get notifications sent today")
}
// Get notifications sent this week
var sentThisWeek int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications
WHERE status = 'sent' AND sent_at >= DATE_TRUNC('week', CURRENT_DATE)
`).Scan(&sentThisWeek)
if err != nil {
s.logger.Error(ctx, "Failed to get notifications sent this week", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get notifications sent this week")
}
// Get upcoming notifications
var upcomingNotifications int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM upcoming_notifications WHERE status = 'pending'
`).Scan(&upcomingNotifications)
if err != nil {
s.logger.Error(ctx, "Failed to get upcoming notifications", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get upcoming notifications")
}
// Get unresolved errors
var unresolvedErrors int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM notification_errors WHERE resolved_at IS NULL
`).Scan(&unresolvedErrors)
if err != nil {
s.logger.Error(ctx, "Failed to get unresolved errors", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get unresolved errors")
}
// Get notifications by type
notificationsByType := make(map[string]int)
rows, err := s.db.QueryContext(ctx, `
SELECT notification_type, COUNT(*)
FROM sent_notifications
WHERE status = 'sent'
GROUP BY notification_type
`)
if err != nil {
s.logger.Error(ctx, "Failed to get notifications by type", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get notifications by type")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var notificationType string
var count int
if err := rows.Scan(¬ificationType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan notifications by type", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to scan notifications by type")
}
notificationsByType[notificationType] = count
}
// Get errors by type
errorsByType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, `
SELECT error_type, COUNT(*)
FROM notification_errors
GROUP BY error_type
`)
if err != nil {
s.logger.Error(ctx, "Failed to get errors by type", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get errors by type")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var errorType string
var count int
if err := rows.Scan(&errorType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan errors by type", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to scan errors by type")
}
errorsByType[errorType] = count
}
stats := map[string]interface{}{
"total_notifications_sent": totalSent,
"total_notifications_failed": totalFailed,
"success_rate": successRate,
"users_with_notifications_enabled": usersWithNotifications,
"total_users": totalUsers,
"notifications_sent_today": sentToday,
"notifications_sent_this_week": sentThisWeek,
"notifications_by_type": notificationsByType,
"errors_by_type": errorsByType,
"upcoming_notifications": upcomingNotifications,
"unresolved_errors": unresolvedErrors,
}
s.logger.Debug(ctx, "Retrieved notification stats", map[string]interface{}{"stats": stats})
return stats, nil
}
// GetNotificationErrors returns paginated notification errors with filtering
func (s *WorkerService) GetNotificationErrors(ctx context.Context, page, pageSize int, errorType, notificationType, resolved string) (result0 []map[string]interface{}, result1, result2 map[string]interface{}, err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "get_notification_errors",
attribute.Int("page", page),
attribute.Int("page_size", pageSize),
attribute.String("error_type", errorType),
attribute.String("notification_type", notificationType),
attribute.String("resolved", resolved),
)
defer observability.FinishSpan(span, &err)
// Build WHERE clause
whereConditions := []string{}
args := []interface{}{}
argIndex := 1
if errorType != "" {
whereConditions = append(whereConditions, fmt.Sprintf("error_type = $%d", argIndex))
args = append(args, errorType)
argIndex++
}
if notificationType != "" {
whereConditions = append(whereConditions, fmt.Sprintf("notification_type = $%d", argIndex))
args = append(args, notificationType)
argIndex++
}
switch resolved {
case "true":
whereConditions = append(whereConditions, "resolved_at IS NOT NULL")
case "false":
whereConditions = append(whereConditions, "resolved_at IS NULL")
}
whereClause := ""
if len(whereConditions) > 0 {
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
}
// Get total count
var totalErrors int
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM notification_errors %s", whereClause)
err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalErrors)
if err != nil {
s.logger.Error(ctx, "Failed to get total notification errors", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get total notification errors")
}
// Calculate pagination
offset := (page - 1) * pageSize
totalPages := (totalErrors + pageSize - 1) / pageSize
// Get errors with pagination
args = append(args, pageSize, offset)
query := fmt.Sprintf(`
SELECT ne.id, ne.user_id, u.username, ne.notification_type, ne.error_type,
ne.error_message, ne.email_address, ne.occurred_at, ne.resolved_at, ne.resolution_notes
FROM notification_errors ne
LEFT JOIN users u ON ne.user_id = u.id
%s
ORDER BY ne.occurred_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIndex, argIndex+1)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
s.logger.Error(ctx, "Failed to get notification errors", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get notification errors")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
var errors []map[string]interface{}
for rows.Next() {
var errorData map[string]interface{}
var id int
var userID sql.NullInt64
var username sql.NullString
var notificationType, errorType, errorMessage string
var emailAddress sql.NullString
var occurredAt time.Time
var resolvedAt sql.NullTime
var resolutionNotes sql.NullString
err := rows.Scan(&id, &userID, &username, ¬ificationType, &errorType, &errorMessage, &emailAddress, &occurredAt, &resolvedAt, &resolutionNotes)
if err != nil {
s.logger.Error(ctx, "Failed to scan notification error", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to scan notification error")
}
errorData = map[string]interface{}{
"id": id,
"notification_type": notificationType,
"error_type": errorType,
"error_message": errorMessage,
"occurred_at": occurredAt.Format(time.RFC3339),
}
if userID.Valid {
errorData["user_id"] = userID.Int64
}
if username.Valid {
errorData["username"] = username.String
}
if emailAddress.Valid {
errorData["email_address"] = emailAddress.String
}
if resolvedAt.Valid {
errorData["resolved_at"] = resolvedAt.Time.Format(time.RFC3339)
}
if resolutionNotes.Valid {
errorData["resolution_notes"] = resolutionNotes.String
}
errors = append(errors, errorData)
}
// Get stats
stats := map[string]interface{}{
"total_errors": totalErrors,
"unresolved_errors": 0, // Will be calculated separately
}
// Get unresolved errors count
var unresolvedCount int
err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM notification_errors WHERE resolved_at IS NULL").Scan(&unresolvedCount)
if err != nil {
s.logger.Error(ctx, "Failed to get unresolved errors count", err, map[string]interface{}{})
} else {
stats["unresolved_errors"] = unresolvedCount
}
// Get errors by type
errorsByType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, "SELECT error_type, COUNT(*) FROM notification_errors GROUP BY error_type")
if err != nil {
s.logger.Error(ctx, "Failed to get errors by type", err, map[string]interface{}{})
} else {
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var errorType string
var count int
if err := rows.Scan(&errorType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan errors by type", err, map[string]interface{}{})
continue
}
errorsByType[errorType] = count
}
stats["errors_by_type"] = errorsByType
}
// Get errors by notification type
errorsByNotificationType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, "SELECT notification_type, COUNT(*) FROM notification_errors GROUP BY notification_type")
if err != nil {
s.logger.Error(ctx, "Failed to get errors by notification type", err, map[string]interface{}{})
} else {
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var notificationType string
var count int
if err := rows.Scan(¬ificationType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan errors by notification type", err, map[string]interface{}{})
continue
}
errorsByNotificationType[notificationType] = count
}
stats["errors_by_notification_type"] = errorsByNotificationType
}
pagination := map[string]interface{}{
"page": page,
"page_size": pageSize,
"total": totalErrors,
"total_pages": totalPages,
}
s.logger.Debug(ctx, "Retrieved notification errors", map[string]interface{}{
"count": len(errors), "page": page, "total": totalErrors,
})
return errors, pagination, stats, nil
}
// GetUpcomingNotifications returns paginated upcoming notifications with filtering
func (s *WorkerService) GetUpcomingNotifications(ctx context.Context, page, pageSize int, notificationType, status, scheduledAfter, scheduledBefore string) (result0 []map[string]interface{}, result1, result2 map[string]interface{}, err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "get_upcoming_notifications",
attribute.Int("page", page),
attribute.Int("page_size", pageSize),
attribute.String("notification_type", notificationType),
attribute.String("status", status),
attribute.String("scheduled_after", scheduledAfter),
attribute.String("scheduled_before", scheduledBefore),
)
defer observability.FinishSpan(span, &err)
// Build WHERE clause
whereConditions := []string{}
args := []interface{}{}
argIndex := 1
if notificationType != "" {
whereConditions = append(whereConditions, fmt.Sprintf("notification_type = $%d", argIndex))
args = append(args, notificationType)
argIndex++
}
if status != "" {
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argIndex))
args = append(args, status)
argIndex++
}
if scheduledAfter != "" {
whereConditions = append(whereConditions, fmt.Sprintf("scheduled_for >= $%d", argIndex))
args = append(args, scheduledAfter)
argIndex++
}
if scheduledBefore != "" {
whereConditions = append(whereConditions, fmt.Sprintf("scheduled_for <= $%d", argIndex))
args = append(args, scheduledBefore)
argIndex++
}
whereClause := ""
if len(whereConditions) > 0 {
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
}
// Get total count
var totalNotifications int
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM upcoming_notifications %s", whereClause)
err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalNotifications)
if err != nil {
s.logger.Error(ctx, "Failed to get total upcoming notifications", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get total upcoming notifications")
}
// Calculate pagination
offset := (page - 1) * pageSize
totalPages := (totalNotifications + pageSize - 1) / pageSize
// Get notifications with pagination
args = append(args, pageSize, offset)
query := fmt.Sprintf(`
SELECT un.id, un.user_id, u.username, u.email, un.notification_type,
un.scheduled_for, un.status, un.created_at
FROM upcoming_notifications un
LEFT JOIN users u ON un.user_id = u.id
%s
ORDER BY un.scheduled_for ASC
LIMIT $%d OFFSET $%d
`, whereClause, argIndex, argIndex+1)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
s.logger.Error(ctx, "Failed to get upcoming notifications", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get upcoming notifications")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
var notifications []map[string]interface{}
for rows.Next() {
var notification map[string]interface{}
var id, userID int
var username, notificationType, status string
var scheduledFor, createdAt time.Time
var email sql.NullString
err := rows.Scan(&id, &userID, &username, &email, ¬ificationType, &scheduledFor, &status, &createdAt)
if err != nil {
s.logger.Error(ctx, "Failed to scan upcoming notification", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to scan upcoming notification")
}
notification = map[string]interface{}{
"id": id,
"user_id": userID,
"username": username,
"notification_type": notificationType,
"scheduled_for": scheduledFor.Format(time.RFC3339),
"status": status,
"created_at": createdAt.Format(time.RFC3339),
}
if email.Valid {
notification["email_address"] = email.String
} else {
notification["email_address"] = ""
}
notifications = append(notifications, notification)
}
// Get stats
stats := map[string]interface{}{
"total_pending": 0,
"total_scheduled_today": 0,
"total_scheduled_this_week": 0,
}
// Get total pending
var totalPending int
err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM upcoming_notifications WHERE status = 'pending'").Scan(&totalPending)
if err != nil {
s.logger.Error(ctx, "Failed to get total pending", err, map[string]interface{}{})
} else {
stats["total_pending"] = totalPending
}
// Get scheduled today
var scheduledToday int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM upcoming_notifications
WHERE status = 'pending' AND DATE(scheduled_for) = CURRENT_DATE
`).Scan(&scheduledToday)
if err != nil {
s.logger.Error(ctx, "Failed to get scheduled today", err, map[string]interface{}{})
} else {
stats["total_scheduled_today"] = scheduledToday
}
// Get scheduled this week
var scheduledThisWeek int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM upcoming_notifications
WHERE status = 'pending' AND scheduled_for >= DATE_TRUNC('week', CURRENT_DATE)
`).Scan(&scheduledThisWeek)
if err != nil {
s.logger.Error(ctx, "Failed to get scheduled this week", err, map[string]interface{}{})
} else {
stats["total_scheduled_this_week"] = scheduledThisWeek
}
// Get notifications by type
notificationsByType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, "SELECT notification_type, COUNT(*) FROM upcoming_notifications GROUP BY notification_type")
if err != nil {
s.logger.Error(ctx, "Failed to get notifications by type", err, map[string]interface{}{})
} else {
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var notificationType string
var count int
if err := rows.Scan(¬ificationType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan notifications by type", err, map[string]interface{}{})
continue
}
notificationsByType[notificationType] = count
}
stats["notifications_by_type"] = notificationsByType
}
pagination := map[string]interface{}{
"page": page,
"page_size": pageSize,
"total": totalNotifications,
"total_pages": totalPages,
}
s.logger.Debug(ctx, "Retrieved upcoming notifications", map[string]interface{}{
"count": len(notifications), "page": page, "total": totalNotifications,
})
return notifications, pagination, stats, nil
}
// GetSentNotifications returns paginated sent notifications with filtering
func (s *WorkerService) GetSentNotifications(ctx context.Context, page, pageSize int, notificationType, status, sentAfter, sentBefore string) (result0 []map[string]interface{}, result1, result2 map[string]interface{}, err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "get_sent_notifications",
attribute.Int("page", page),
attribute.Int("page_size", pageSize),
attribute.String("notification_type", notificationType),
attribute.String("status", status),
attribute.String("sent_after", sentAfter),
attribute.String("sent_before", sentBefore),
)
defer observability.FinishSpan(span, &err)
// Build WHERE clause
whereConditions := []string{}
args := []interface{}{}
argIndex := 1
if notificationType != "" {
whereConditions = append(whereConditions, fmt.Sprintf("notification_type = $%d", argIndex))
args = append(args, notificationType)
argIndex++
}
if status != "" {
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argIndex))
args = append(args, status)
argIndex++
}
if sentAfter != "" {
whereConditions = append(whereConditions, fmt.Sprintf("sent_at >= $%d", argIndex))
args = append(args, sentAfter)
argIndex++
}
if sentBefore != "" {
whereConditions = append(whereConditions, fmt.Sprintf("sent_at <= $%d", argIndex))
args = append(args, sentBefore)
argIndex++
}
whereClause := ""
if len(whereConditions) > 0 {
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
}
// Get total count
var totalNotifications int
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sent_notifications %s", whereClause)
err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalNotifications)
if err != nil {
s.logger.Error(ctx, "Failed to get total sent notifications", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get total sent notifications")
}
// Calculate pagination
offset := (page - 1) * pageSize
totalPages := (totalNotifications + pageSize - 1) / pageSize
// Get notifications with pagination
args = append(args, pageSize, offset)
query := fmt.Sprintf(`
SELECT sn.id, sn.user_id, u.username, u.email, sn.notification_type,
sn.subject, sn.template_name, sn.sent_at, sn.status, sn.error_message, sn.retry_count
FROM sent_notifications sn
LEFT JOIN users u ON sn.user_id = u.id
%s
ORDER BY sn.sent_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIndex, argIndex+1)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
s.logger.Error(ctx, "Failed to get sent notifications", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get sent notifications")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
var notifications []map[string]interface{}
for rows.Next() {
var notification map[string]interface{}
var id, userID int
var username, notificationType, subject, templateName, status string
var sentAt time.Time
var errorMessage sql.NullString
var retryCount int
var email sql.NullString
err := rows.Scan(&id, &userID, &username, &email, ¬ificationType, &subject, &templateName, &sentAt, &status, &errorMessage, &retryCount)
if err != nil {
s.logger.Error(ctx, "Failed to scan sent notification", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to scan sent notification")
}
notification = map[string]interface{}{
"id": id,
"user_id": userID,
"username": username,
"notification_type": notificationType,
"subject": subject,
"template_name": templateName,
"sent_at": sentAt.Format(time.RFC3339),
"status": status,
"retry_count": retryCount,
}
if email.Valid {
notification["email_address"] = email.String
} else {
notification["email_address"] = ""
}
if errorMessage.Valid {
notification["error_message"] = errorMessage.String
}
notifications = append(notifications, notification)
}
// Get stats
stats := map[string]interface{}{
"total_sent": 0,
"total_failed": 0,
"success_rate": 0.0,
"sent_today": 0,
"sent_this_week": 0,
}
// Get total sent
var totalSent int
err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sent_notifications WHERE status = 'sent'").Scan(&totalSent)
if err != nil {
s.logger.Error(ctx, "Failed to get total sent", err, map[string]interface{}{})
} else {
stats["total_sent"] = totalSent
}
// Get total failed
var totalFailed int
err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sent_notifications WHERE status = 'failed'").Scan(&totalFailed)
if err != nil {
s.logger.Error(ctx, "Failed to get total failed", err, map[string]interface{}{})
} else {
stats["total_failed"] = totalFailed
}
// Calculate success rate
if totalSent+totalFailed > 0 {
stats["success_rate"] = float64(totalSent) / float64(totalSent+totalFailed)
}
// Get sent today
var sentToday int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications
WHERE status = 'sent' AND DATE(sent_at) = CURRENT_DATE
`).Scan(&sentToday)
if err != nil {
s.logger.Error(ctx, "Failed to get sent today", err, map[string]interface{}{})
} else {
stats["sent_today"] = sentToday
}
// Get sent this week
var sentThisWeek int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications
WHERE status = 'sent' AND sent_at >= DATE_TRUNC('week', CURRENT_DATE)
`).Scan(&sentThisWeek)
if err != nil {
s.logger.Error(ctx, "Failed to get sent this week", err, map[string]interface{}{})
} else {
stats["sent_this_week"] = sentThisWeek
}
// Get notifications by type
notificationsByType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, "SELECT notification_type, COUNT(*) FROM sent_notifications GROUP BY notification_type")
if err != nil {
s.logger.Error(ctx, "Failed to get notifications by type", err, map[string]interface{}{})
} else {
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var notificationType string
var count int
if err := rows.Scan(¬ificationType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan notifications by type", err, map[string]interface{}{})
continue
}
notificationsByType[notificationType] = count
}
stats["notifications_by_type"] = notificationsByType
}
pagination := map[string]interface{}{
"page": page,
"page_size": pageSize,
"total": totalNotifications,
"total_pages": totalPages,
}
s.logger.Debug(ctx, "Retrieved sent notifications", map[string]interface{}{
"count": len(notifications), "page": page, "total": totalNotifications,
})
return notifications, pagination, stats, nil
}
// CreateTestSentNotification inserts a single row into sent_notifications for
// use in tests. The row's sent_at is the time of the call; every other column
// value is taken from the arguments as given. The outcome is logged either
// way, and a failed insert is recorded on the span and returned wrapped.
func (s *WorkerService) CreateTestSentNotification(ctx context.Context, userID int, notificationType, subject, templateName, status, errorMessage string) error {
	ctx, span := observability.TraceWorkerFunction(ctx, "create_test_sent_notification",
		attribute.Int("user.id", userID),
		attribute.String("notification.type", notificationType),
		attribute.String("notification.status", status),
	)
	defer span.End()

	// The same identifying fields accompany both the failure and success log.
	logFields := map[string]interface{}{
		"user_id":           userID,
		"notification_type": notificationType,
		"status":            status,
	}

	const insertSQL = `
INSERT INTO sent_notifications (user_id, notification_type, subject, template_name, sent_at, status, error_message)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	if _, execErr := s.db.ExecContext(ctx, insertSQL, userID, notificationType, subject, templateName, time.Now(), status, errorMessage); execErr != nil {
		span.RecordError(execErr)
		s.logger.Error(ctx, "Failed to create test sent notification", execErr, logFields)
		return contextutils.WrapError(execErr, "failed to create test sent notification")
	}

	s.logger.Info(ctx, "Created test sent notification", logFields)
	return nil
}
// Package contextutils provides error handling utilities and standardized error types
// for consistent error management across the quiz application.
package contextutils
import (
"context"
"fmt"
"strings"
)
// ErrorCode is the string identifier attached to an error so that API
// responses can report failures with a stable, standardized code.
type ErrorCode string

// Database error codes.
const (
	// ErrorCodeDatabaseConnection reports a failure to connect to the database.
	ErrorCodeDatabaseConnection ErrorCode = "DATABASE_CONNECTION_ERROR"
	// ErrorCodeDatabaseQuery reports a failed database query.
	ErrorCodeDatabaseQuery ErrorCode = "DATABASE_QUERY_ERROR"
	// ErrorCodeDatabaseTransaction reports a failed database transaction.
	ErrorCodeDatabaseTransaction ErrorCode = "DATABASE_TRANSACTION_ERROR"
	// ErrorCodeRecordNotFound reports that the requested record does not exist.
	ErrorCodeRecordNotFound ErrorCode = "RECORD_NOT_FOUND"
	// ErrorCodeRecordExists reports a duplicate: the record already exists.
	ErrorCodeRecordExists ErrorCode = "RECORD_ALREADY_EXISTS"
	// ErrorCodeForeignKeyViolation reports a foreign key constraint violation.
	ErrorCodeForeignKeyViolation ErrorCode = "FOREIGN_KEY_VIOLATION"
)

// Validation error codes.
const (
	// ErrorCodeInvalidInput reports that the supplied input is invalid.
	ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
	// ErrorCodeMissingRequired reports that a required field was not supplied.
	ErrorCodeMissingRequired ErrorCode = "MISSING_REQUIRED_FIELD"
	// ErrorCodeInvalidFormat reports that the input has an invalid format.
	ErrorCodeInvalidFormat ErrorCode = "INVALID_FORMAT"
	// ErrorCodeValidationFailed reports that validation did not pass.
	ErrorCodeValidationFailed ErrorCode = "VALIDATION_FAILED"
)

// Authentication error codes.
const (
	// ErrorCodeUnauthorized reports that the user is not authorized.
	ErrorCodeUnauthorized ErrorCode = "UNAUTHORIZED"
	// ErrorCodeForbidden reports that the user may not access the resource.
	ErrorCodeForbidden ErrorCode = "FORBIDDEN"
	// ErrorCodeInvalidCredentials reports that the supplied credentials are invalid.
	ErrorCodeInvalidCredentials ErrorCode = "INVALID_CREDENTIALS"
	// ErrorCodeSessionExpired reports that the user's session has expired.
	ErrorCodeSessionExpired ErrorCode = "SESSION_EXPIRED"
)

// Service error codes.
const (
	// ErrorCodeServiceUnavailable reports that the service is temporarily unavailable.
	ErrorCodeServiceUnavailable ErrorCode = "SERVICE_UNAVAILABLE"
	// ErrorCodeTimeout reports that a request timed out.
	ErrorCodeTimeout ErrorCode = "REQUEST_TIMEOUT"
	// ErrorCodeRateLimit reports that the rate limit was exceeded.
	ErrorCodeRateLimit ErrorCode = "RATE_LIMIT_EXCEEDED"
	// ErrorCodeQuotaExceeded reports that the usage quota was exceeded.
	ErrorCodeQuotaExceeded ErrorCode = "QUOTA_EXCEEDED"
	// ErrorCodeInternalError reports an internal server error.
	ErrorCodeInternalError ErrorCode = "INTERNAL_SERVER_ERROR"
	// ErrorCodeAssignmentNotFound reports that a question assignment was not found.
	ErrorCodeAssignmentNotFound ErrorCode = "ASSIGNMENT_NOT_FOUND"
	// ErrorCodeConflict reports that an operation conflicts with the current state.
	ErrorCodeConflict ErrorCode = "CONFLICT"
)

// Question error codes.
const (
	// ErrorCodeTimestampMissingTimezone reports a timestamp without timezone information.
	ErrorCodeTimestampMissingTimezone ErrorCode = "TIMESTAMP_MISSING_TIMEZONE"
	// ErrorCodeNoQuestionsAvailable reports that no questions are available.
	ErrorCodeNoQuestionsAvailable ErrorCode = "NO_QUESTIONS_AVAILABLE"
	// ErrorCodeQuestionAlreadyAnswered reports that the question was already answered.
	ErrorCodeQuestionAlreadyAnswered ErrorCode = "QUESTION_ALREADY_ANSWERED"
	// ErrorCodeQuestionNotFound reports that the requested question was not found.
	ErrorCodeQuestionNotFound ErrorCode = "QUESTION_NOT_FOUND"
	// ErrorCodeInvalidAnswerIndex reports an invalid answer index.
	ErrorCodeInvalidAnswerIndex ErrorCode = "INVALID_ANSWER_INDEX"
	// ErrorCodeGenerationLimitReached reports that the daily generation limit was reached.
	ErrorCodeGenerationLimitReached ErrorCode = "GENERATION_LIMIT_REACHED"
)

// AI service error codes.
const (
	// ErrorCodeAIProviderUnavailable reports that the AI provider is unavailable.
	ErrorCodeAIProviderUnavailable ErrorCode = "AI_PROVIDER_UNAVAILABLE"
	// ErrorCodeAIRequestFailed reports that the AI request failed.
	ErrorCodeAIRequestFailed ErrorCode = "AI_REQUEST_FAILED"
	// ErrorCodeAIResponseInvalid reports that the AI response is invalid.
	ErrorCodeAIResponseInvalid ErrorCode = "AI_RESPONSE_INVALID"
	// ErrorCodeAIConfigInvalid reports that the AI configuration is invalid.
	ErrorCodeAIConfigInvalid ErrorCode = "AI_CONFIG_INVALID"
)

// OAuth error codes.
const (
	// ErrorCodeOAuthCodeExpired reports that the OAuth authorization code has expired.
	ErrorCodeOAuthCodeExpired ErrorCode = "OAUTH_CODE_EXPIRED"
	// ErrorCodeOAuthStateMismatch reports that the OAuth state parameter does not match.
	ErrorCodeOAuthStateMismatch ErrorCode = "OAUTH_STATE_MISMATCH"
	// ErrorCodeOAuthProviderError reports an error returned by the OAuth provider.
	ErrorCodeOAuthProviderError ErrorCode = "OAUTH_PROVIDER_ERROR"
)
// SeverityLevel is the string label describing how severe an error is, for
// use by logging and monitoring.
type SeverityLevel string

// Severity levels, from least to most severe.
const (
	// SeverityDebug marks debug-level errors, relevant during development.
	SeverityDebug = SeverityLevel("debug")
	// SeverityInfo marks informational errors.
	SeverityInfo = SeverityLevel("info")
	// SeverityWarn marks warning-level errors.
	SeverityWarn = SeverityLevel("warn")
	// SeverityError marks error-level issues.
	SeverityError = SeverityLevel("error")
	// SeverityFatal marks fatal errors that need immediate attention.
	SeverityFatal = SeverityLevel("fatal")
)
// AppError is a structured application error: a standardized code and
// severity, a human-readable message with optional details, and the
// underlying error that caused it, if any.
type AppError struct {
	Code     ErrorCode
	Severity SeverityLevel
	Message  string
	Details  string
	Cause    error
}

// Error implements the error interface. The text is "<code>: <message>", with
// " - <details>" appended when Details is set; Cause is not part of the text.
func (e *AppError) Error() string {
	if e.Details == "" {
		return fmt.Sprintf("%s: %s", e.Code, e.Message)
	}
	return fmt.Sprintf("%s: %s - %s", e.Code, e.Message, e.Details)
}

// Unwrap exposes the underlying cause so errors.Is and errors.As can walk the chain.
func (e *AppError) Unwrap() error {
	return e.Cause
}

// Is reports whether target is an *AppError carrying the same Code, which is
// how errors.Is matches an AppError; all other fields are ignored.
func (e *AppError) Is(target error) bool {
	other, isAppError := target.(*AppError)
	return isAppError && e.Code == other.Code
}
// Sentinel errors for consistent error handling. Each one pairs an error
// code with a default severity and an English message; compare against them
// with errors.Is or IsError, both of which match on the code.

// Database errors
var (
	ErrDatabaseConnection  = &AppError{Code: ErrorCodeDatabaseConnection, Severity: SeverityError, Message: "Database connection failed"}
	ErrDatabaseQuery       = &AppError{Code: ErrorCodeDatabaseQuery, Severity: SeverityError, Message: "Database query failed"}
	ErrDatabaseTransaction = &AppError{Code: ErrorCodeDatabaseTransaction, Severity: SeverityError, Message: "Database transaction failed"}
	ErrRecordNotFound      = &AppError{Code: ErrorCodeRecordNotFound, Severity: SeverityInfo, Message: "Record not found"}
	ErrRecordExists        = &AppError{Code: ErrorCodeRecordExists, Severity: SeverityInfo, Message: "Record already exists"}
	ErrForeignKeyViolation = &AppError{Code: ErrorCodeForeignKeyViolation, Severity: SeverityError, Message: "Foreign key constraint violation"}
)

// Validation errors
var (
	ErrInvalidInput     = &AppError{Code: ErrorCodeInvalidInput, Severity: SeverityWarn, Message: "Invalid input"}
	ErrMissingRequired  = &AppError{Code: ErrorCodeMissingRequired, Severity: SeverityWarn, Message: "Missing required field"}
	ErrInvalidFormat    = &AppError{Code: ErrorCodeInvalidFormat, Severity: SeverityWarn, Message: "Invalid format"}
	ErrValidationFailed = &AppError{Code: ErrorCodeValidationFailed, Severity: SeverityWarn, Message: "Validation failed"}
)

// Authentication errors
var (
	ErrUnauthorized       = &AppError{Code: ErrorCodeUnauthorized, Severity: SeverityWarn, Message: "Unauthorized"}
	ErrForbidden          = &AppError{Code: ErrorCodeForbidden, Severity: SeverityWarn, Message: "Forbidden"}
	ErrInvalidCredentials = &AppError{Code: ErrorCodeInvalidCredentials, Severity: SeverityWarn, Message: "Invalid credentials"}
	ErrSessionExpired     = &AppError{Code: ErrorCodeSessionExpired, Severity: SeverityInfo, Message: "Session expired"}
)

// Service errors
var (
	ErrServiceUnavailable = &AppError{Code: ErrorCodeServiceUnavailable, Severity: SeverityError, Message: "Service unavailable"}
	ErrTimeout            = &AppError{Code: ErrorCodeTimeout, Severity: SeverityWarn, Message: "Request timeout"}
	ErrRateLimit          = &AppError{Code: ErrorCodeRateLimit, Severity: SeverityWarn, Message: "Rate limit exceeded"}
	ErrQuotaExceeded      = &AppError{Code: ErrorCodeQuotaExceeded, Severity: SeverityWarn, Message: "Usage quota exceeded"}
	ErrInternalError      = &AppError{Code: ErrorCodeInternalError, Severity: SeverityError, Message: "Internal server error"}
	ErrAssignmentNotFound = &AppError{Code: ErrorCodeAssignmentNotFound, Severity: SeverityInfo, Message: "Assignment not found"}
	ErrConflict           = &AppError{Code: ErrorCodeConflict, Severity: SeverityWarn, Message: "Operation conflicts with current state"}
)

// Question errors
var (
	ErrTimestampMissingTimezone = &AppError{Code: ErrorCodeTimestampMissingTimezone, Severity: SeverityError, Message: "Timestamp missing timezone"}
	ErrNoQuestionsAvailable     = &AppError{Code: ErrorCodeNoQuestionsAvailable, Severity: SeverityInfo, Message: "No questions available for assignment"}
	ErrQuestionAlreadyAnswered  = &AppError{Code: ErrorCodeQuestionAlreadyAnswered, Severity: SeverityInfo, Message: "Question already answered"}
	ErrQuestionNotFound         = &AppError{Code: ErrorCodeQuestionNotFound, Severity: SeverityInfo, Message: "Question not found"}
	ErrInvalidAnswerIndex       = &AppError{Code: ErrorCodeInvalidAnswerIndex, Severity: SeverityWarn, Message: "Invalid answer index"}
	ErrGenerationLimitReached   = &AppError{Code: ErrorCodeGenerationLimitReached, Severity: SeverityInfo, Message: "Daily generation limit reached"}
)

// AI Service errors
var (
	ErrAIProviderUnavailable = &AppError{Code: ErrorCodeAIProviderUnavailable, Severity: SeverityError, Message: "AI provider unavailable"}
	ErrAIRequestFailed       = &AppError{Code: ErrorCodeAIRequestFailed, Severity: SeverityError, Message: "AI request failed"}
	ErrAIResponseInvalid     = &AppError{Code: ErrorCodeAIResponseInvalid, Severity: SeverityError, Message: "AI response invalid"}
	ErrAIConfigInvalid       = &AppError{Code: ErrorCodeAIConfigInvalid, Severity: SeverityError, Message: "AI configuration invalid"}
)

// OAuth errors
var (
	ErrOAuthCodeExpired   = &AppError{Code: ErrorCodeOAuthCodeExpired, Severity: SeverityWarn, Message: "OAuth code expired"}
	ErrOAuthStateMismatch = &AppError{Code: ErrorCodeOAuthStateMismatch, Severity: SeverityError, Message: "OAuth state mismatch"}
	ErrOAuthProviderError = &AppError{Code: ErrorCodeOAuthProviderError, Severity: SeverityError, Message: "OAuth provider error"}
)
// NewAppError builds an AppError from the given code, severity, message and
// details. The returned error has no underlying cause.
func NewAppError(code ErrorCode, severity SeverityLevel, message, details string) *AppError {
	appErr := new(AppError)
	appErr.Code = code
	appErr.Severity = severity
	appErr.Message = message
	appErr.Details = details
	return appErr
}
// NewAppErrorWithCause builds an AppError like NewAppError does and records
// cause as the underlying error.
func NewAppErrorWithCause(code ErrorCode, severity SeverityLevel, message, details string, cause error) *AppError {
	appErr := &AppError{Cause: cause}
	appErr.Code, appErr.Severity = code, severity
	appErr.Message, appErr.Details = message, details
	return appErr
}
// WrapError returns an AppError whose Message is context, whose Details is
// err's text and whose Cause is err. When err is itself an *AppError its
// Code and Severity are carried over; any other error is classified as an
// internal error with error severity. A nil err yields nil.
func WrapError(err error, context string) error {
	if err == nil {
		return nil
	}

	// Start from the generic classification and refine it below.
	wrapped := &AppError{
		Code:     ErrorCodeInternalError,
		Severity: SeverityError,
		Message:  context,
		Details:  err.Error(),
		Cause:    err,
	}
	if appErr, ok := err.(*AppError); ok {
		wrapped.Code = appErr.Code
		wrapped.Severity = appErr.Severity
	}
	return wrapped
}
// WrapErrorf is the formatted counterpart of WrapError. The Message is the
// result of formatting format with args; Details is err's text. When err is
// an *AppError its Code and Severity are preserved, otherwise the result is
// an internal error with error severity. A nil err yields nil.
//
// If format contains the %w verb it is expanded with fmt.Errorf and that
// fmt.Errorf result becomes the Cause; otherwise err itself is the Cause.
func WrapErrorf(err error, format string, args ...interface{}) error {
	if err == nil {
		return nil
	}

	var message string
	cause := err
	if strings.Contains(format, "%w") {
		// fmt.Sprintf does not understand %w, so let fmt.Errorf expand it.
		formatted := fmt.Errorf(format, args...)
		message = formatted.Error()
		cause = formatted
	} else {
		message = fmt.Sprintf(format, args...)
	}

	wrapped := &AppError{
		Code:     ErrorCodeInternalError,
		Severity: SeverityError,
		Message:  message,
		Details:  err.Error(),
		Cause:    cause,
	}
	if appErr, ok := err.(*AppError); ok {
		wrapped.Code = appErr.Code
		wrapped.Severity = appErr.Severity
	}
	return wrapped
}
// ErrorWithContextf returns a new internal-error AppError with error
// severity whose Message is format expanded with args.
func ErrorWithContextf(format string, args ...interface{}) error {
	message := fmt.Sprintf(format, args...)
	return &AppError{
		Severity: SeverityError,
		Code:     ErrorCodeInternalError,
		Message:  message,
	}
}
// IsError reports whether err is an *AppError with the same Code as target.
// Only err itself is inspected; wrapped errors are not unwrapped.
// It returns false when target is nil, when err is not an *AppError, or when
// err holds a nil *AppError, instead of panicking on a nil dereference.
func IsError(err error, target *AppError) bool {
	if target == nil {
		return false
	}
	appErr, ok := err.(*AppError)
	if !ok || appErr == nil {
		return false
	}
	return appErr.Code == target.Code
}
// AsError stores err in *target and returns true when err is an *AppError.
// Otherwise it returns false and leaves *target untouched. Only err itself
// is inspected; wrapped errors are not unwrapped.
func AsError(err error, target **AppError) bool {
	appErr, ok := err.(*AppError)
	if !ok {
		return false
	}
	*target = appErr
	return true
}
// GetErrorCode returns err's Code when err is an *AppError and
// ErrorCodeInternalError for every other error.
func GetErrorCode(err error) ErrorCode {
	appErr, ok := err.(*AppError)
	if !ok {
		return ErrorCodeInternalError
	}
	return appErr.Code
}
// GetErrorSeverity returns err's Severity when err is an *AppError and
// SeverityError for every other error.
func GetErrorSeverity(err error) SeverityLevel {
	appErr, ok := err.(*AppError)
	if !ok {
		return SeverityError
	}
	return appErr.Severity
}
// IsRetryable reports whether err is worth retrying. Only *AppError values
// with a timeout, service-unavailable or database-connection code qualify,
// and only when their severity is not fatal.
func IsRetryable(err error) bool {
	appErr, ok := err.(*AppError)
	if !ok {
		return false
	}
	// These codes are treated as likely transient.
	transient := appErr.Code == ErrorCodeTimeout ||
		appErr.Code == ErrorCodeServiceUnavailable ||
		appErr.Code == ErrorCodeDatabaseConnection
	return transient && appErr.Severity != SeverityFatal
}
// GetErrorLocalizedMessage returns the message for err in the language
// parsed from locale, followed by the error's details when present.
// Errors that are not *AppError produce the fixed text "An error occurred".
func GetErrorLocalizedMessage(err error, locale string) string {
	appErr, ok := err.(*AppError)
	if !ok {
		return "An error occurred"
	}
	lang := ParseLocale(locale)
	return GetLocalizedMessageWithDetails(appErr.Code, lang, appErr.Details)
}
// ToJSON renders the error as a map suitable for JSON API responses.
//
// The map always contains "code", "message", "severity", "retryable" and
// "error" (a duplicate of the message kept for backward compatibility).
// "details" is added when Details is non-empty, and "cause" is added when a
// cause exists and the severity is error or fatal.
func (e *AppError) ToJSON() map[string]interface{} {
	payload := make(map[string]interface{}, 7)
	payload["code"] = string(e.Code)
	payload["message"] = e.Message
	payload["severity"] = string(e.Severity)
	payload["error"] = e.Message // backward-compatible alias of "message"
	payload["retryable"] = IsRetryable(e)

	if e.Details != "" {
		payload["details"] = e.Details
	}
	// The cause text is only exposed for the two most severe levels.
	if e.Cause != nil && (e.Severity == SeverityError || e.Severity == SeverityFatal) {
		payload["cause"] = e.Cause.Error()
	}
	return payload
}
// ToJSONWithLocale returns the same map as ToJSON, except that "message" and
// "error" both hold the message localized for the language parsed from
// locale.
func (e *AppError) ToJSONWithLocale(locale string) map[string]interface{} {
	payload := e.ToJSON()
	localized := GetLocalizedMessage(e.Code, ParseLocale(locale))
	// Overwrite both keys so the legacy "error" field stays in sync.
	for _, key := range []string{"message", "error"} {
		payload[key] = localized
	}
	return payload
}
// ContextKey is the dedicated string type used for this package's
// context.Context keys. Using a named type keeps these keys from colliding
// with plain string keys set by other packages.
type ContextKey string
const (
// UserIDKey is the context key under which WithUserID stores the user ID
// (an int) and from which GetUserIDFromContext reads it.
UserIDKey ContextKey = "userID"
// APIKeyIDKey is the context key under which WithAPIKeyID stores the API
// key ID (a *int) and from which GetAPIKeyIDFromContext reads it.
APIKeyIDKey ContextKey = "apiKeyID"
)
// GetUserIDFromContext returns the int stored in ctx under UserIDKey.
// It returns 0 when the key is absent or holds a value of another type.
func GetUserIDFromContext(ctx context.Context) int {
	// A failed assertion leaves id at its zero value, which is the fallback.
	id, _ := ctx.Value(UserIDKey).(int)
	return id
}
// GetAPIKeyIDFromContext returns the *int stored in ctx under APIKeyIDKey.
// It returns nil when the key is absent or holds a value of another type.
func GetAPIKeyIDFromContext(ctx context.Context) *int {
	// A failed assertion leaves id nil, which is the fallback.
	id, _ := ctx.Value(APIKeyIDKey).(*int)
	return id
}
// WithUserID derives a context from ctx that carries userID under UserIDKey.
func WithUserID(ctx context.Context, userID int) context.Context {
	derived := context.WithValue(ctx, UserIDKey, userID)
	return derived
}
// WithAPIKeyID derives a context from ctx that carries a pointer to a copy
// of apiKeyID under APIKeyIDKey.
func WithAPIKeyID(ctx context.Context, apiKeyID int) context.Context {
	idPtr := &apiKeyID
	return context.WithValue(ctx, APIKeyIDKey, idPtr)
}
package contextutils
import (
"encoding/json"
"fmt"
"strings"
)
// Locale is a lowercase language code (e.g., "en", "es", "fr") used as the
// lookup key for localized error messages.
type Locale string
// Language codes with named constants. LocaleEnglish is also the fallback
// returned by ParseLocale and tried by LocalizedMessages.GetMessage.
const (
// LocaleEnglish represents English language
LocaleEnglish Locale = "en"
// LocaleSpanish represents Spanish language
LocaleSpanish Locale = "es"
// LocaleFrench represents French language
LocaleFrench Locale = "fr"
// LocaleGerman represents German language
LocaleGerman Locale = "de"
// LocaleItalian represents Italian language
LocaleItalian Locale = "it"
)
// LocalizedMessages is a registry of localized error messages.
// Create it with NewLocalizedMessages and populate it with AddMessage or
// LoadMessagesFromJSON.
type LocalizedMessages struct {
// messages maps an error code to its per-locale message text.
messages map[ErrorCode]map[Locale]string
}
// NewLocalizedMessages returns an empty, ready-to-use message registry.
func NewLocalizedMessages() *LocalizedMessages {
	table := make(map[ErrorCode]map[Locale]string)
	return &LocalizedMessages{messages: table}
}
// AddMessage registers message as the text for code in locale, replacing any
// message previously registered for that pair.
//
// The maps are allocated on demand, so AddMessage also works on a
// zero-value LocalizedMessages that was not built by NewLocalizedMessages
// (writing to its nil map would otherwise panic).
func (lm *LocalizedMessages) AddMessage(code ErrorCode, locale Locale, message string) {
	if lm.messages == nil {
		lm.messages = make(map[ErrorCode]map[Locale]string)
	}
	byLocale := lm.messages[code]
	if byLocale == nil {
		byLocale = make(map[Locale]string)
		lm.messages[code] = byLocale
	}
	byLocale[locale] = message
}
// GetMessage resolves the message for code using, in order: the message
// registered for locale, the message registered for English, and finally
// the built-in default from getDefaultMessage.
func (lm *LocalizedMessages) GetMessage(code ErrorCode, locale Locale) string {
	byLocale, registered := lm.messages[code]
	if !registered {
		return getDefaultMessage(code)
	}
	// Requested locale first, then the English fallback.
	for _, candidate := range []Locale{locale, LocaleEnglish} {
		if text, ok := byLocale[candidate]; ok {
			return text
		}
	}
	return getDefaultMessage(code)
}
// GetMessageWithDetails returns the localized message for code and, when
// details is non-empty, appends it as "<message>: <details>".
func (lm *LocalizedMessages) GetMessageWithDetails(code ErrorCode, locale Locale, details string) string {
	base := lm.GetMessage(code, locale)
	if details == "" {
		return base
	}
	return base + ": " + details
}
// getDefaultMessage returns the built-in English message for code. It is the
// last-resort fallback used by GetMessage when no localized or English
// message has been registered. Unknown codes yield "An error occurred".
//
// Every error code that has a sentinel error in this package has a case
// here, including quota-exceeded, conflict and generation-limit, which were
// previously missing and fell through to the generic text.
func getDefaultMessage(code ErrorCode) string {
	switch code {
	// Database errors
	case ErrorCodeDatabaseConnection:
		return "Database connection failed"
	case ErrorCodeDatabaseQuery:
		return "Database query failed"
	case ErrorCodeDatabaseTransaction:
		return "Database transaction failed"
	case ErrorCodeRecordNotFound:
		return "Record not found"
	case ErrorCodeRecordExists:
		return "Record already exists"
	case ErrorCodeForeignKeyViolation:
		return "Foreign key constraint violation"

	// Validation errors
	case ErrorCodeInvalidInput:
		return "Invalid input"
	case ErrorCodeMissingRequired:
		return "Missing required field"
	case ErrorCodeInvalidFormat:
		return "Invalid format"
	case ErrorCodeValidationFailed:
		return "Validation failed"

	// Authentication errors
	case ErrorCodeUnauthorized:
		return "Unauthorized access"
	case ErrorCodeForbidden:
		return "Access forbidden"
	case ErrorCodeInvalidCredentials:
		return "Invalid credentials"
	case ErrorCodeSessionExpired:
		return "Session expired"

	// Service errors
	case ErrorCodeServiceUnavailable:
		return "Service temporarily unavailable"
	case ErrorCodeTimeout:
		return "Request timeout"
	case ErrorCodeRateLimit:
		return "Rate limit exceeded"
	case ErrorCodeQuotaExceeded:
		return "Usage quota exceeded"
	case ErrorCodeInternalError:
		return "Internal server error"
	case ErrorCodeAssignmentNotFound:
		return "Assignment not found"
	case ErrorCodeConflict:
		return "Operation conflicts with current state"

	// Question errors
	case ErrorCodeTimestampMissingTimezone:
		return "Timestamp missing timezone"
	case ErrorCodeNoQuestionsAvailable:
		return "No questions available"
	case ErrorCodeQuestionAlreadyAnswered:
		return "Question already answered"
	case ErrorCodeQuestionNotFound:
		return "Question not found"
	case ErrorCodeInvalidAnswerIndex:
		return "Invalid answer index"
	case ErrorCodeGenerationLimitReached:
		return "Daily generation limit reached"

	// AI service errors
	case ErrorCodeAIProviderUnavailable:
		return "AI service unavailable"
	case ErrorCodeAIRequestFailed:
		return "AI request failed"
	case ErrorCodeAIResponseInvalid:
		return "AI response invalid"
	case ErrorCodeAIConfigInvalid:
		return "AI configuration invalid"

	// OAuth errors
	case ErrorCodeOAuthCodeExpired:
		return "OAuth code expired"
	case ErrorCodeOAuthStateMismatch:
		return "OAuth state mismatch"
	case ErrorCodeOAuthProviderError:
		return "OAuth provider error"

	default:
		return "An error occurred"
	}
}
// LoadMessagesFromJSON registers every message found in jsonData, which must
// be a JSON object of the form {"<error code>": {"<locale>": "<message>"}}.
// Existing entries for the same code and locale are overwritten. A parse
// failure is returned wrapped and leaves the registry unchanged.
func (lm *LocalizedMessages) LoadMessagesFromJSON(jsonData string) error {
	var parsed map[string]map[string]string
	err := json.Unmarshal([]byte(jsonData), &parsed)
	if err != nil {
		return WrapError(err, "failed to parse localization JSON")
	}

	for rawCode, perLocale := range parsed {
		for rawLocale, text := range perLocale {
			lm.AddMessage(ErrorCode(rawCode), Locale(rawLocale), text)
		}
	}
	return nil
}
// GetSupportedLocales returns each locale that has at least one registered
// message, without duplicates. The order of the result is unspecified
// because it comes from map iteration.
func (lm *LocalizedMessages) GetSupportedLocales() []Locale {
	seen := make(map[Locale]struct{})
	for _, perLocale := range lm.messages {
		for loc := range perLocale {
			seen[loc] = struct{}{}
		}
	}

	supported := make([]Locale, 0, len(seen))
	for loc := range seen {
		supported = append(supported, loc)
	}
	return supported
}
// ParseLocale extracts the lowercase language part of a locale string.
//
// Both hyphenated tags ("en-US", "fr-CA") and underscore-separated tags
// ("en_US") are accepted, and surrounding whitespace is ignored. An empty
// string, or a string with no language before the separator, yields
// LocaleEnglish.
func ParseLocale(localeStr string) Locale {
	lang := strings.TrimSpace(localeStr)
	// Keep only what precedes the first region separator.
	if sep := strings.IndexAny(lang, "-_"); sep >= 0 {
		lang = lang[:sep]
	}
	if lang == "" {
		return LocaleEnglish // Default fallback
	}
	return Locale(strings.ToLower(lang))
}
// globalLocalizedMessages is the package-level registry read by
// GetLocalizedMessage and GetLocalizedMessageWithDetails. It is seeded by
// init and can be replaced with SetGlobalLocalizedMessages.
var globalLocalizedMessages = NewLocalizedMessages()
// init loads default localized messages
func init() {
// Load some basic localized messages
globalLocalizedMessages.AddMessage(ErrorCodeInvalidInput, LocaleSpanish, "Entrada invÃlida")
globalLocalizedMessages.AddMessage(ErrorCodeInvalidInput, LocaleFrench, "EntrÃe invalide")
globalLocalizedMessages.AddMessage(ErrorCodeInvalidInput, LocaleGerman, "UngÃltige Eingabe")
globalLocalizedMessages.AddMessage(ErrorCodeRecordNotFound, LocaleSpanish, "Registro no encontrado")
globalLocalizedMessages.AddMessage(ErrorCodeRecordNotFound, LocaleFrench, "Enregistrement non trouvÃ")
globalLocalizedMessages.AddMessage(ErrorCodeRecordNotFound, LocaleGerman, "Datensatz nicht gefunden")
globalLocalizedMessages.AddMessage(ErrorCodeUnauthorized, LocaleSpanish, "Acceso no autorizado")
globalLocalizedMessages.AddMessage(ErrorCodeUnauthorized, LocaleFrench, "AccÃs non autorisÃ")
globalLocalizedMessages.AddMessage(ErrorCodeUnauthorized, LocaleGerman, "Unbefugter Zugriff")
globalLocalizedMessages.AddMessage(ErrorCodeInternalError, LocaleSpanish, "Error interno del servidor")
globalLocalizedMessages.AddMessage(ErrorCodeInternalError, LocaleFrench, "Erreur interne du serveur")
globalLocalizedMessages.AddMessage(ErrorCodeInternalError, LocaleGerman, "Interner Serverfehler")
}
// GetLocalizedMessage looks up the message for code and locale in the
// package-level registry, with the fallbacks described on GetMessage.
func GetLocalizedMessage(code ErrorCode, locale Locale) string {
	registry := globalLocalizedMessages
	return registry.GetMessage(code, locale)
}
// GetLocalizedMessageWithDetails looks up the message for code and locale in
// the package-level registry and appends details when it is non-empty.
func GetLocalizedMessageWithDetails(code ErrorCode, locale Locale, details string) string {
	registry := globalLocalizedMessages
	return registry.GetMessageWithDetails(code, locale, details)
}
// SetGlobalLocalizedMessages replaces the package-level registry that
// GetLocalizedMessage and GetLocalizedMessageWithDetails read from.
// NOTE(review): nil is not rejected, and a nil registry would make later
// lookups dereference a nil pointer. The assignment is also unsynchronized;
// confirm callers only use this during start-up or in tests.
func SetGlobalLocalizedMessages(messages *LocalizedMessages) {
globalLocalizedMessages = messages
}
package contextutils
import (
"strings"
)
// MaskAPIKey returns a version of apiKey that is safe to log.
//
// An empty key becomes "[EMPTY]". Keys of up to eight bytes are replaced
// entirely by asterisks. Longer keys keep their first four and last four
// bytes, with every byte in between replaced by an asterisk, so the result
// has the same length as the input.
func MaskAPIKey(apiKey string) string {
	const visible = 4 // bytes left readable at each end
	n := len(apiKey)

	switch {
	case n == 0:
		return "[EMPTY]"
	case n <= 2*visible:
		return strings.Repeat("*", n)
	}

	var masked strings.Builder
	masked.Grow(n)
	masked.WriteString(apiKey[:visible])
	masked.WriteString(strings.Repeat("*", n-2*visible))
	masked.WriteString(apiKey[n-visible:])
	return masked.String()
}
package contextutils
import (
"context"
"time"
"quizapp/internal/models"
)
// ParseDateInUserTimezone interprets dateStr (layout YYYY-MM-DD) as midnight
// in the timezone configured for the user identified by userID.
//
// userLookup is injected so callers and tests control how the user is
// fetched. If it fails, its error is returned unchanged with an empty
// timezone name. When the user is nil, has no timezone, or has one that
// cannot be loaded, UTC is used.
//
// It returns the parsed time, the name of the timezone actually applied,
// and an error. A malformed date yields an error wrapped with the message
// "invalid date format" alongside the effective timezone name.
func ParseDateInUserTimezone(
	ctx context.Context,
	userID int,
	dateStr string,
	userLookup func(context.Context, int) (*models.User, error),
) (time.Time, string, error) {
	user, err := userLookup(ctx, userID)
	if err != nil {
		return time.Time{}, "", err
	}

	// Default to UTC; switch to the user's zone only if it loads.
	tzName, loc := "UTC", time.UTC
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		if userLoc, locErr := time.LoadLocation(user.Timezone.String); locErr == nil {
			tzName, loc = user.Timezone.String, userLoc
		}
	}

	parsed, err := time.ParseInLocation("2006-01-02", dateStr, loc)
	if err != nil {
		return time.Time{}, tzName, WrapError(err, "invalid date format")
	}
	return parsed, tzName, nil
}
// ConvertTimeToUserLocation returns t expressed in the timezone configured
// for the user identified by userID, together with the name of the timezone
// that was applied.
//
// UTC is used when userLookup is nil, when the user is nil or has no
// timezone, or when the configured timezone cannot be loaded. A nil
// userLookup is tolerated because FormatTimeInUserTimezone accepts one and
// forwards it here; previously that path panicked.
//
// If userLookup fails, its error is returned unchanged with a zero time and
// an empty timezone name.
func ConvertTimeToUserLocation(
	ctx context.Context,
	userID int,
	t time.Time,
	userLookup func(context.Context, int) (*models.User, error),
) (time.Time, string, error) {
	timezone := "UTC"
	if userLookup != nil {
		user, err := userLookup(ctx, userID)
		if err != nil {
			return time.Time{}, "", err
		}
		if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
			timezone = user.Timezone.String
		}
	}

	loc, err := time.LoadLocation(timezone)
	if err != nil {
		// Unknown timezone name: fall back to UTC.
		loc = time.UTC
		timezone = "UTC"
	}
	return t.In(loc), timezone, nil
}
// FormatTimeInUserTimezone converts t to the user's timezone and formats it
// with layout, returning the formatted text and the timezone name applied.
//
// A timestamp that is exactly midnight UTC may be a date-only value stored
// without timezone information. If t is such a value, userLookup is
// non-nil, and the lookup succeeds with a user whose timezone is set to
// something other than "UTC", ErrTimestampMissingTimezone is returned.
// A failed lookup during this check is ignored; the conversion that follows
// performs its own lookup and reports any error.
func FormatTimeInUserTimezone(
	ctx context.Context,
	userID int,
	t time.Time,
	layout string,
	userLookup func(context.Context, int) (*models.User, error),
) (string, string, error) {
	isUTCMidnight := t.Location() == time.UTC &&
		t.Hour() == 0 && t.Minute() == 0 && t.Second() == 0 && t.Nanosecond() == 0

	if isUTCMidnight && userLookup != nil {
		u, lookupErr := userLookup(ctx, userID)
		hasNonUTCZone := lookupErr == nil && u != nil &&
			u.Timezone.Valid && u.Timezone.String != "" && u.Timezone.String != "UTC"
		if hasNonUTCZone {
			return "", "", ErrTimestampMissingTimezone
		}
	}

	converted, tzName, err := ConvertTimeToUserLocation(ctx, userID, t, userLookup)
	if err != nil {
		return "", tzName, err
	}
	return converted.Format(layout), tzName, nil
}
// UserLocalDayRange returns the UTC start and end timestamps that cover the
// last `days` calendar days for the given user in their configured timezone.
//
// The range is [startUTC, endUTC): startUTC is 00:00 local time on the
// earliest day, and endUTC is 00:00 local time on the day after today, both
// converted to UTC. A days value of zero or less is treated as 1.
//
// userLookup fetches the user; its error is returned unchanged. UTC is used
// when the user is nil, has no timezone, or has one that cannot be loaded.
// The third result is the name of the timezone actually applied.
func UserLocalDayRange(ctx context.Context, userID, days int, userLookup func(context.Context, int) (*models.User, error)) (time.Time, time.Time, string, error) {
	if days <= 0 {
		days = 1
	}

	user, err := userLookup(ctx, userID)
	if err != nil {
		return time.Time{}, time.Time{}, "", err
	}

	timezone := "UTC"
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		timezone = user.Timezone.String
	}
	loc, err := time.LoadLocation(timezone)
	if err != nil {
		loc = time.UTC
		timezone = "UTC"
	}

	now := time.Now().In(loc)
	today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, loc)
	startLocal := today.AddDate(0, 0, -(days - 1))
	// Use calendar arithmetic for the end bound as well. Adding a fixed 24h
	// lands at 23:00 or 01:00 local time on days with a DST transition.
	endLocal := today.AddDate(0, 0, 1)

	return startLocal.UTC(), endLocal.UTC(), timezone, nil
}
package contextutils
import (
"github.com/go-playground/validator/v10"
)
// validate is the package-level validator instance used by IsValidEmail.
var validate = validator.New()
// IsValidEmail reports whether email passes the "email" rule of
// go-playground/validator.
func IsValidEmail(email string) bool {
	err := validate.Var(email, "email")
	return err == nil
}
// Package worker contains the background worker responsible for generating
// and maintaining daily question assignments, scheduling generation jobs,
// and reporting worker health. The worker runs independently of HTTP
// request handling and interacts with the database, AI providers, and
// other internal services to keep question queues primed for users.
package worker
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"math"
"os"
"strconv"
"strings"
"sync"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/services/mailer"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// Status represents the current state of the worker and is JSON-serializable
// through the field tags below. How and when each field is updated is not
// visible in this declaration.
type Status struct {
IsRunning bool `json:"is_running"`
IsPaused bool `json:"is_paused"`
// CurrentActivity is omitted from the JSON output when empty.
CurrentActivity string `json:"current_activity,omitempty"`
LastRunStart time.Time `json:"last_run_start"`
LastRunFinish time.Time `json:"last_run_finish"`
// LastRunError is omitted from the JSON output when empty.
LastRunError string `json:"last_run_error,omitempty"`
NextRun time.Time `json:"next_run"`
}
// RunRecord tracks an individual worker run: when it started and ended, how
// long it took, and its outcome.
type RunRecord struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
// Duration is a time.Duration, so encoding/json writes it as an integer
// number of nanoseconds.
Duration time.Duration `json:"duration"`
Status string `json:"status"` // Success, Failure
Details string `json:"details"`
}
// ActivityLog represents a single activity log entry.
type ActivityLog struct {
Timestamp time.Time `json:"timestamp"`
Level string `json:"level"` // INFO, WARN, ERROR
Message string `json:"message"`
// UserID and Username are pointers so they can be nil; nil values are
// omitted from the JSON output.
UserID *int `json:"user_id,omitempty"`
Username *string `json:"username,omitempty"`
}
// UserFailureInfo tracks failure information for exponential backoff.
// The Worker keeps one entry per user ID in its userFailures map.
// NOTE(review): the backoff computation itself is not in this declaration;
// field meanings below are inferred from their names.
type UserFailureInfo struct {
// ConsecutiveFailures presumably counts failures since the last success.
ConsecutiveFailures int
LastFailureTime time.Time
// NextRetryTime is presumably the earliest time another attempt is allowed.
NextRetryTime time.Time
}
// Config holds worker-specific configuration.
type Config struct {
// StartWorkerPaused presumably makes the worker start in the paused state;
// its use is not visible here.
StartWorkerPaused bool
// DailyHorizonDays is read by checkForDailyQuestionAssignments as the
// assignment horizon in days; values of zero or less are replaced by a
// default of 2 there.
DailyHorizonDays int
}
// Worker manages AI question generation in the background.
// It bundles the service dependencies the background jobs call, the
// worker's own status and history, and the mutexes guarding that state.
type Worker struct {
// Service dependencies used by the worker's jobs.
userService services.UserServiceInterface
questionService services.QuestionServiceInterface
aiService services.AIServiceInterface
learningService services.LearningServiceInterface
workerService services.WorkerServiceInterface
dailyQuestionService services.DailyQuestionServiceInterface
wordOfTheDayService services.WordOfTheDayServiceInterface
storyService services.StoryServiceInterface
emailService mailer.Mailer
hintService services.GenerationHintServiceInterface
translationCacheRepo services.TranslationCacheRepository
// instance identifies this worker; it is attached to trace spans and log
// entries as "worker.instance" / "instance".
instance string
status Status
history []RunRecord
activityLogs []ActivityLog // Circular buffer for recent activity logs
// NOTE(review): mu presumably guards status, history and activityLogs;
// confirm against the methods that use it.
mu sync.RWMutex
manualTrigger chan bool
// cfg is the application configuration; workerCfg holds worker-only options.
cfg *config.Config
workerCfg Config
logger *observability.Logger
lastTranslationCleanup time.Time // Track last translation cache cleanup
// translationCleanupMu guards lastTranslationCleanup.
translationCleanupMu sync.RWMutex
// Track failures for exponential backoff
userFailures map[int]*UserFailureInfo // userID -> failure info
failureMu sync.RWMutex // mutex for failure tracking
// Time function for testing - defaults to time.Now
timeNow func() time.Time
cancel context.CancelFunc // Added for cleanup
}
// cleanupTranslationCache deletes expired translation cache entries, at most
// once per day.
//
// The call is a no-op when the previous successful cleanup falls in the same
// 24-hour bucket as the current time (as computed by Time.Truncate). On
// success the cleanup time is recorded; on failure it is left unchanged and
// the repository error is returned wrapped. Progress is reported on the
// trace span and in the log.
func (w *Worker) cleanupTranslationCache(ctx context.Context) error {
	ctx, span := otel.Tracer("worker").Start(ctx, "cleanupTranslationCache",
		trace.WithAttributes(
			attribute.String("worker.instance", w.instance),
		),
	)
	defer span.End()

	// Snapshot the last cleanup time under the lock.
	w.translationCleanupMu.Lock()
	previous := w.lastTranslationCleanup
	w.translationCleanupMu.Unlock()

	now := w.timeNow()

	const day = 24 * time.Hour
	alreadyRanToday := !previous.IsZero() && previous.Truncate(day).Equal(now.Truncate(day))
	if alreadyRanToday {
		span.SetAttributes(
			attribute.Bool("cleanup.skipped", true),
			attribute.String("cleanup.last_run", previous.Format(time.RFC3339)),
		)
		return nil
	}

	w.logger.Info(ctx, "Cleaning up expired translation cache entries", map[string]interface{}{
		"last_cleanup": previous,
	})

	deleted, err := w.translationCacheRepo.CleanupExpiredTranslations(ctx)
	if err != nil {
		span.RecordError(err)
		span.SetAttributes(attribute.Bool("cleanup.success", false))
		return contextutils.WrapError(err, "failed to cleanup expired translation cache entries")
	}

	// Remember this run so later calls in the same bucket are skipped.
	w.translationCleanupMu.Lock()
	w.lastTranslationCleanup = now
	w.translationCleanupMu.Unlock()

	span.SetAttributes(
		attribute.Bool("cleanup.success", true),
		attribute.Int64("cleanup.deleted_count", deleted),
	)
	w.logger.Info(ctx, "Translation cache cleanup completed", map[string]interface{}{
		"deleted_count": deleted,
		"instance":      w.instance,
	})
	return nil
}
// checkForDailyReminders sends the daily reminder email to every eligible
// user when the current UTC hour equals the configured reminder hour.
//
// It returns nil without doing anything when daily reminders are disabled
// or it is not the reminder hour. Otherwise it fetches the eligible users
// and, for each one, sends the reminder, records the notification with
// status "sent" or "failed", and updates the user's last-reminder
// timestamp. Failures for one user are logged and do not stop the loop.
// Only a failure to obtain the user list is returned as an error.
func (w *Worker) checkForDailyReminders(ctx context.Context) error {
	ctx, span := otel.Tracer("worker").Start(ctx, "checkForDailyReminders",
		trace.WithAttributes(
			attribute.String("worker.instance", w.instance),
			attribute.Bool("email.daily_reminder.enabled", w.cfg.Email.DailyReminder.Enabled),
			attribute.Int("email.daily_reminder.hour", w.cfg.Email.DailyReminder.Hour),
			attribute.Bool("email.enabled", w.cfg.Email.Enabled),
		),
	)
	defer span.End()

	if !w.cfg.Email.DailyReminder.Enabled {
		w.logger.Info(ctx, "Daily reminders disabled, skipping", nil)
		return nil
	}

	// The reminder hour is compared against the current hour in UTC.
	now := w.timeNow().UTC()
	currentHour := now.Hour()
	reminderHour := w.cfg.Email.DailyReminder.Hour
	if currentHour != reminderHour {
		span.SetAttributes(
			attribute.Int("check.current_hour", currentHour),
			attribute.Int("check.reminder_hour", reminderHour),
			attribute.Bool("check.should_send", false),
			attribute.String("check.reason", "wrong_hour"),
		)
		return nil
	}

	span.SetAttributes(
		attribute.Int("check.current_hour", currentHour),
		attribute.Int("check.reminder_hour", reminderHour),
		attribute.Bool("check.should_send", true),
	)
	w.logger.Info(ctx, "Checking for users needing daily reminders", map[string]interface{}{
		"reminder_hour": reminderHour,
	})

	// Get users who need daily reminders
	users, err := w.getUsersNeedingDailyReminders(ctx)
	if err != nil {
		span.RecordError(err)
		span.SetAttributes(
			attribute.Int("users.total", 0),
			attribute.Int("users.eligible", 0),
			attribute.Int("reminders.sent", 0),
		)
		w.logger.Error(ctx, "Failed to get users needing daily reminders", err, nil)
		return contextutils.WrapError(err, "failed to get users needing daily reminders")
	}
	span.SetAttributes(
		attribute.Int("users.total", len(users)),
	)

	remindersSent := 0
	failedReminders := 0
	for _, user := range users {
		// NOTE(review): the trailing character of the subject looks like a
		// truncated emoji; confirm the intended text and file encoding.
		subject := "Time for your daily quiz! ð"
		status := "sent"
		errorMsg := ""

		if err := w.emailService.SendDailyReminder(ctx, &user); err != nil {
			failedReminders++
			status = "failed"
			errorMsg = err.Error()
			w.logger.Error(ctx, "Failed to send daily reminder", err, map[string]interface{}{
				"user_id": user.ID,
				"email":   user.Email.String,
			})
		} else {
			remindersSent++
		}

		// Record the attempt (sent or failed) in the database.
		if err := w.emailService.RecordSentNotification(ctx, user.ID, "daily_reminder", subject, "daily_reminder", status, errorMsg); err != nil {
			w.logger.Error(ctx, "Failed to record sent notification", err, map[string]interface{}{
				"user_id": user.ID,
			})
		}

		// The last-reminder timestamp is updated whether or not the send
		// succeeded, so a user is attempted at most once per day. A failure
		// here is logged but not counted as a failed reminder.
		// NOTE(review): confirm that skipping retries after a failed send
		// is intended.
		if err := w.learningService.UpdateLastDailyReminderSent(ctx, user.ID); err != nil {
			w.logger.Error(ctx, "Failed to update last daily reminder sent timestamp", err, map[string]interface{}{
				"user_id": user.ID,
			})
		}
	}

	// Avoid 0/0: with no eligible users the rate would be NaN, which is not
	// a meaningful attribute value.
	successRate := 0.0
	if len(users) > 0 {
		successRate = float64(remindersSent) / float64(len(users))
	}
	span.SetAttributes(
		attribute.Int("users.eligible", len(users)),
		attribute.Int("reminders.sent", remindersSent),
		attribute.Int("reminders.failed", failedReminders),
		attribute.Float64("reminders.success_rate", successRate),
	)
	w.logger.Info(ctx, "Daily reminders processed", map[string]interface{}{
		"total_users":      len(users),
		"reminders_sent":   remindersSent,
		"reminders_failed": failedReminders,
		"reminder_hour":    reminderHour,
	})
	return nil
}
// getUsersNeedingDailyReminders returns the users who should receive a daily
// reminder now. A user qualifies when they have a non-empty email address,
// their learning preferences have daily reminders enabled, and no reminder
// has been sent to them on the current UTC calendar day.
//
// Users whose preferences cannot be loaded are skipped with a warning. Only
// a failure to list the users is returned as an error.
func (w *Worker) getUsersNeedingDailyReminders(ctx context.Context) ([]models.User, error) {
	ctx, span := otel.Tracer("worker").Start(ctx, "getUsersNeedingDailyReminders")
	defer span.End()

	users, err := w.userService.GetAllUsers(ctx)
	if err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "failed to get users")
	}

	var eligibleUsers []models.User
	today := w.timeNow().UTC().Format("2006-01-02")
	for _, user := range users {
		// Check if user has email address
		if !user.Email.Valid || user.Email.String == "" {
			continue
		}

		// Get user's learning preferences to check daily reminder setting
		prefs, err := w.learningService.GetUserLearningPreferences(ctx, user.ID)
		if err != nil {
			w.logger.Warn(ctx, "Failed to get user learning preferences for daily reminder check", map[string]interface{}{
				"user_id":  user.ID,
				"username": user.Username,
				"error":    err.Error(),
			})
			continue
		}

		// Check if daily reminders are enabled for this user
		if prefs == nil || !prefs.DailyReminderEnabled {
			continue
		}

		// Skip users already reminded today. Both dates are formatted in UTC
		// so the comparison does not depend on the location attached to the
		// stored timestamp.
		if prefs.LastDailyReminderSent != nil {
			lastReminderDate := prefs.LastDailyReminderSent.UTC().Format("2006-01-02")
			if lastReminderDate == today {
				continue
			}
		}

		eligibleUsers = append(eligibleUsers, user)
	}

	w.logger.Info(ctx, "Found users eligible for daily reminders", map[string]interface{}{
		"total_users":    len(users),
		"eligible_users": len(eligibleUsers),
	})
	return eligibleUsers, nil
}
// checkForDailyQuestionAssignments pre-assigns daily questions for every
// eligible user, one call per local calendar day from today through the
// effective horizon. It runs independently of the email reminder flow, so
// users with reminders disabled still receive their questions.
//
// A failed assignment for one user/date is logged and counted but does not
// stop the pass. The only error returned is a failure to load the list of
// eligible users.
func (w *Worker) checkForDailyQuestionAssignments(ctx context.Context) error {
	ctx, span := observability.TraceWorkerFunction(ctx, "check_for_daily_question_assignments",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	w.logger.Info(ctx, "Checking for daily question assignments", map[string]any{
		"instance": w.instance,
	})

	eligible, listErr := w.getUsersEligibleForDailyQuestions(ctx)
	if listErr != nil {
		span.RecordError(listErr)
		w.logger.Error(ctx, "Failed to get users eligible for daily questions", listErr, nil)
		return contextutils.WrapError(listErr, "failed to get users eligible for daily questions")
	}
	if len(eligible) == 0 {
		w.logger.Info(ctx, "No users eligible for daily question assignments", map[string]any{
			"instance": w.instance,
		})
		return nil
	}
	span.SetAttributes(attribute.Int("users.total", len(eligible)))

	var assigned, failed int
	for _, user := range eligible {
		// Resolve the user's timezone name; an unset value means UTC.
		tzName := "UTC"
		if user.Timezone.Valid && user.Timezone.String != "" {
			tzName = user.Timezone.String
		}
		userLoc, tzErr := time.LoadLocation(tzName)
		if tzErr != nil {
			w.logger.Warn(ctx, "Invalid timezone for user, using UTC", map[string]any{
				"user_id":  user.ID,
				"username": user.Username,
				"timezone": tzName,
				"error":    tzErr.Error(),
			})
			userLoc = time.UTC
		}

		// Midnight of the current day as seen from the user's location.
		localNow := w.timeNow().In(userLoc)
		startOfToday := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, userLoc)

		// Number of days ahead to cover; a non-positive setting falls back to 2.
		daysAhead := w.workerCfg.DailyHorizonDays
		if daysAhead <= 0 {
			daysAhead = 2
		}

		// The horizon must be at least as long as the repeat-avoid window from
		// the server config (7 when unset), so that future slots freed up
		// elsewhere (e.g. after a correct submission) are refilled by this run.
		minDays := 7
		if w.cfg != nil && w.cfg.Server.DailyRepeatAvoidDays > 0 {
			minDays = w.cfg.Server.DailyRepeatAvoidDays
		}
		if daysAhead < minDays {
			w.logger.Info(ctx, "Extending worker daily horizon to cover daily repeat avoid window", map[string]any{
				"old_horizon": daysAhead,
				"new_horizon": minDays,
				"user_id":     user.ID,
			})
			daysAhead = minDays
		}

		// Inclusive range: today plus daysAhead further days.
		for offset := 0; offset <= daysAhead; offset++ {
			day := startOfToday.AddDate(0, 0, offset)
			assignErr := w.dailyQuestionService.AssignDailyQuestions(ctx, user.ID, day)
			if assignErr == nil {
				assigned++
				continue
			}
			failed++
			w.logger.Error(ctx, "Failed to assign daily questions", assignErr, map[string]any{
				"user_id":  user.ID,
				"username": user.Username,
				"timezone": tzName,
				"date":     day.Format("2006-01-02"),
			})
		}
	}

	span.SetAttributes(
		attribute.Int("assignments.successful", assigned),
		attribute.Int("assignments.failed", failed),
	)
	return nil
}
// getUsersEligibleForDailyQuestions loads all users and keeps those that have
// a preferred language, a current level, and AI enabled. Email reminder
// preferences play no part in this decision. Each skipped user is reported at
// debug level with the reason.
//
// It returns an error only when the user list itself cannot be loaded.
func (w *Worker) getUsersEligibleForDailyQuestions(ctx context.Context) ([]models.User, error) {
	ctx, span := otel.Tracer("worker").Start(ctx, "getUsersEligibleForDailyQuestions")
	defer span.End()

	allUsers, loadErr := w.userService.GetAllUsers(ctx)
	if loadErr != nil {
		span.RecordError(loadErr)
		return nil, contextutils.WrapError(loadErr, "failed to get users")
	}

	var eligibleUsers []models.User
	for _, candidate := range allUsers {
		// The first failing requirement decides the skip message.
		skipMsg := ""
		switch {
		case !candidate.PreferredLanguage.Valid || candidate.PreferredLanguage.String == "":
			skipMsg = "User missing preferred language, skipping daily question assignment"
		case !candidate.CurrentLevel.Valid || candidate.CurrentLevel.String == "":
			skipMsg = "User missing current level, skipping daily question assignment"
		case !candidate.AIEnabled.Valid || !candidate.AIEnabled.Bool:
			// Users with AI disabled are not eligible for daily questions.
			skipMsg = "User has AI disabled, skipping daily question assignment"
		}
		if skipMsg != "" {
			w.logger.Debug(ctx, skipMsg, map[string]any{
				"user_id":  candidate.ID,
				"username": candidate.Username,
			})
			continue
		}
		eligibleUsers = append(eligibleUsers, candidate)
	}

	w.logger.Info(ctx, "Found users eligible for daily questions", map[string]any{
		"total_users":    len(allUsers),
		"eligible_users": len(eligibleUsers),
	})
	return eligibleUsers, nil
}
// checkForWordOfTheDayAssignments makes sure every eligible user has a word of
// the day for the current date in their own timezone, by asking the
// word-of-the-day service for it.
//
// services.ErrNoSuitableWord is treated as a normal outcome: it is logged at
// info level and counted neither as success nor as failure. Any other
// per-user error is logged and counted without aborting the pass. The only
// error returned is a failure to load the eligible-user list.
func (w *Worker) checkForWordOfTheDayAssignments(ctx context.Context) error {
	ctx, span := observability.TraceWorkerFunction(ctx, "check_for_word_of_the_day_assignments",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	w.logger.Info(ctx, "Checking for word of the day assignments", map[string]any{
		"instance": w.instance,
	})

	eligible, listErr := w.getUsersEligibleForWordOfTheDay(ctx)
	if listErr != nil {
		span.RecordError(listErr)
		w.logger.Error(ctx, "Failed to get users eligible for word of the day", listErr, nil)
		return contextutils.WrapError(listErr, "failed to get users eligible for word of the day")
	}
	if len(eligible) == 0 {
		w.logger.Info(ctx, "No users eligible for word of the day assignments", map[string]any{
			"instance": w.instance,
		})
		return nil
	}
	span.SetAttributes(attribute.Int("users.total", len(eligible)))

	var assigned, failed int
	for _, user := range eligible {
		// Resolve the user's timezone name; an unset value means UTC.
		tzName := "UTC"
		if user.Timezone.Valid && user.Timezone.String != "" {
			tzName = user.Timezone.String
		}
		userLoc, tzErr := time.LoadLocation(tzName)
		if tzErr != nil {
			w.logger.Warn(ctx, "Invalid timezone for user, using UTC", map[string]any{
				"user_id":  user.ID,
				"username": user.Username,
				"timezone": tzName,
				"error":    tzErr.Error(),
			})
			userLoc = time.UTC
		}

		// Midnight of the current day as seen from the user's location.
		localNow := w.timeNow().In(userLoc)
		startOfToday := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, userLoc)

		// The service call is expected to be idempotent: fetch the existing
		// word or create one if missing.
		_, wordErr := w.wordOfTheDayService.GetWordOfTheDay(ctx, user.ID, startOfToday)
		switch {
		case wordErr == nil:
			assigned++
		case errors.Is(wordErr, services.ErrNoSuitableWord):
			// Nothing suitable to offer today is not a failure.
			w.logger.Info(ctx, "No suitable word available for user today", map[string]any{
				"user_id":  user.ID,
				"username": user.Username,
				"timezone": tzName,
				"date":     startOfToday.Format("2006-01-02"),
			})
		default:
			failed++
			w.logger.Error(ctx, "Failed to assign word of the day", wordErr, map[string]any{
				"user_id":  user.ID,
				"username": user.Username,
				"timezone": tzName,
				"date":     startOfToday.Format("2006-01-02"),
			})
		}
	}

	span.SetAttributes(
		attribute.Int("assignments.successful", assigned),
		attribute.Int("assignments.failed", failed),
	)
	return nil
}
// getUsersEligibleForWordOfTheDay loads all users and keeps those that have a
// preferred language, a current level, and AI enabled. Users that do not
// qualify are dropped silently.
//
// It returns an error only when the user list itself cannot be loaded.
func (w *Worker) getUsersEligibleForWordOfTheDay(ctx context.Context) ([]models.User, error) {
	ctx, span := otel.Tracer("worker").Start(ctx, "getUsersEligibleForWordOfTheDay")
	defer span.End()

	allUsers, loadErr := w.userService.GetAllUsers(ctx)
	if loadErr != nil {
		span.RecordError(loadErr)
		return nil, contextutils.WrapError(loadErr, "failed to get users")
	}

	var eligibleUsers []models.User
	for _, candidate := range allUsers {
		hasLanguage := candidate.PreferredLanguage.Valid && candidate.PreferredLanguage.String != ""
		hasLevel := candidate.CurrentLevel.Valid && candidate.CurrentLevel.String != ""
		aiOn := candidate.AIEnabled.Valid && candidate.AIEnabled.Bool
		if hasLanguage && hasLevel && aiOn {
			eligibleUsers = append(eligibleUsers, candidate)
		}
	}

	w.logger.Info(ctx, "Found users eligible for word of the day", map[string]any{
		"total_users":    len(allUsers),
		"eligible_users": len(eligibleUsers),
	})
	return eligibleUsers, nil
}
// checkForWordOfTheDayEmails emails today's word of the day to every user who
// opted in. It is a no-op unless daily reminder email is enabled in the
// config and the current UTC hour equals the configured reminder hour.
//
// Per-user failures (word lookup or send) are logged and counted but never
// abort the pass. The only error returned is a failure to load the recipient
// list.
//
// NOTE(review): nothing in this function records that an email was sent, so
// repeated runs within the reminder hour rely on something else to prevent
// duplicates — confirm.
func (w *Worker) checkForWordOfTheDayEmails(ctx context.Context) error {
	ctx, span := observability.TraceWorkerFunction(ctx, "check_for_word_of_the_day_emails",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	if !w.cfg.Email.DailyReminder.Enabled {
		w.logger.Info(ctx, "Email disabled, skipping word of the day emails", nil)
		return nil
	}

	// Word of the day emails share the send hour of the daily reminders; the
	// comparison is done in UTC.
	if w.timeNow().UTC().Hour() != w.cfg.Email.DailyReminder.Hour {
		return nil
	}

	recipients, listErr := w.getUsersNeedingWordOfTheDayEmails(ctx)
	if listErr != nil {
		span.RecordError(listErr)
		return contextutils.WrapError(listErr, "failed to get users needing word of the day emails")
	}
	span.SetAttributes(attribute.Int("users.total", len(recipients)))

	var sent, failed int
	for _, user := range recipients {
		// Resolve the user's timezone; unset or unloadable values mean UTC.
		tzName := "UTC"
		if user.Timezone.Valid && user.Timezone.String != "" {
			tzName = user.Timezone.String
		}
		userLoc, tzErr := time.LoadLocation(tzName)
		if tzErr != nil {
			userLoc = time.UTC
		}

		// Midnight of the current day as seen from the user's location.
		localNow := w.timeNow().In(userLoc)
		startOfToday := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, userLoc)

		word, wordErr := w.wordOfTheDayService.GetWordOfTheDay(ctx, user.ID, startOfToday)
		if wordErr != nil {
			failed++
			w.logger.Error(ctx, "Failed to get word of the day for email", wordErr, map[string]any{
				"user_id":  user.ID,
				"username": user.Username,
			})
			continue
		}
		if word == nil {
			// Nothing to send for this user today.
			continue
		}

		// The configured mailer must also implement the richer email service
		// interface to be able to send this kind of email.
		sender, supported := w.emailService.(services.EmailServiceInterface)
		if !supported {
			w.logger.Warn(ctx, "Email service does not support word of the day emails", map[string]any{
				"user_id": user.ID,
			})
			continue
		}

		sendErr := sender.SendWordOfTheDayEmail(ctx, user.ID, startOfToday, word)
		if sendErr == nil {
			sent++
			continue
		}
		failed++
		w.logger.Error(ctx, "Failed to send word of the day email", sendErr, map[string]any{
			"user_id":  user.ID,
			"username": user.Username,
		})
	}

	span.SetAttributes(
		attribute.Int("emails.sent", sent),
		attribute.Int("emails.failed", failed),
	)
	return nil
}
// getUsersNeedingWordOfTheDayEmails loads all users and keeps those that have
// a non-empty email address and the word-of-the-day email flag set.
//
// It returns an error only when the user list itself cannot be loaded.
func (w *Worker) getUsersNeedingWordOfTheDayEmails(ctx context.Context) ([]models.User, error) {
	ctx, span := otel.Tracer("worker").Start(ctx, "getUsersNeedingWordOfTheDayEmails")
	defer span.End()

	allUsers, loadErr := w.userService.GetAllUsers(ctx)
	if loadErr != nil {
		span.RecordError(loadErr)
		return nil, contextutils.WrapError(loadErr, "failed to get users")
	}

	var eligibleUsers []models.User
	for _, candidate := range allUsers {
		hasEmail := candidate.Email.Valid && candidate.Email.String != ""
		// Only the Bool part of the flag is consulted.
		optedIn := candidate.WordOfDayEmailEnabled.Bool
		if hasEmail && optedIn {
			eligibleUsers = append(eligibleUsers, candidate)
		}
	}

	w.logger.Info(ctx, "Found users eligible for word of the day emails", map[string]any{
		"total_users":    len(allUsers),
		"eligible_users": len(eligibleUsers),
	})
	return eligibleUsers, nil
}
// NewWorker wires the given services and configuration into a Worker.
//
// An empty instance name becomes "default". The daily horizon is taken from
// cfg.Server.DailyHorizonDays and falls back to 1 when that value is not
// positive. When the WORKER_START_PAUSED environment variable parses as true,
// the global pause is requested right away, during construction.
//
// cfg must be non-nil: its Server section is read unconditionally.
func NewWorker(userService services.UserServiceInterface, questionService services.QuestionServiceInterface, aiService services.AIServiceInterface, learningService services.LearningServiceInterface, workerService services.WorkerServiceInterface, dailyQuestionService services.DailyQuestionServiceInterface, wordOfTheDayService services.WordOfTheDayServiceInterface, storyService services.StoryServiceInterface, emailService mailer.Mailer, hintService services.GenerationHintServiceInterface, translationCacheRepo services.TranslationCacheRepository, instance string, cfg *config.Config, logger *observability.Logger) *Worker {
	if instance == "" {
		instance = "default"
	}

	// The config file value wins when it is positive; otherwise use 1.
	horizonDays := cfg.Server.DailyHorizonDays
	if horizonDays <= 0 {
		horizonDays = 1
	}

	runtimeCfg := Config{
		StartWorkerPaused: getEnvBool("WORKER_START_PAUSED", false),
		DailyHorizonDays:  horizonDays,
	}

	// The cancel function is kept on the worker so Shutdown can invoke it.
	ctx, cancel := context.WithCancel(context.Background())

	w := &Worker{
		// Collaborating services.
		userService:          userService,
		questionService:      questionService,
		aiService:            aiService,
		learningService:      learningService,
		workerService:        workerService,
		dailyQuestionService: dailyQuestionService,
		wordOfTheDayService:  wordOfTheDayService,
		storyService:         storyService,
		emailService:         emailService,
		hintService:          hintService,
		translationCacheRepo: translationCacheRepo,

		// Identity, configuration and logging.
		instance:  instance,
		cfg:       cfg,
		workerCfg: runtimeCfg,
		logger:    logger,

		// Mutable runtime state.
		status:        Status{IsRunning: false, CurrentActivity: "Initialized"},
		history:       make([]RunRecord, 0, cfg.Server.MaxHistory),
		activityLogs:  make([]ActivityLog, 0, cfg.Server.MaxActivityLogs),
		manualTrigger: make(chan bool, 1),
		userFailures:  make(map[int]*UserFailureInfo),
		timeNow:       time.Now, // real clock by default
		cancel:        cancel,
	}

	if w.workerCfg.StartWorkerPaused {
		w.handleStartupPause(ctx)
	}
	return w
}
// getEnvBool reads the environment variable key as a boolean. It returns
// defaultValue when the variable is unset, empty, or not accepted by
// strconv.ParseBool.
func getEnvBool(key string, defaultValue bool) bool {
	raw := os.Getenv(key)
	if raw != "" {
		if parsed, err := strconv.ParseBool(raw); err == nil {
			return parsed
		}
	}
	return defaultValue
}
// Start marks the worker as running, applies the startup pause if configured,
// launches the heartbeat goroutine, and then blocks in the main loop until ctx
// is cancelled. A worker cycle is executed on every tick of the heartbeat
// interval and whenever a manual trigger is received.
//
// On cancellation the worker logs the shutdown, marks itself as not running,
// persists that status, and returns.
func (w *Worker) Start(ctx context.Context) {
	w.status.IsRunning = true
	w.updateDatabaseStatus(ctx)
	w.handleStartupPause(ctx)

	// The heartbeat runs on its own goroutine and stops with ctx.
	go w.heartbeatLoop(ctx)

	cycle := time.NewTicker(config.WorkerHeartbeatInterval)
	defer cycle.Stop()

	startState := w.getInitialWorkerStatus(ctx)
	w.logger.Info(ctx, "Worker started", map[string]any{
		"instance": w.instance,
		"status":   startState,
	})
	w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s started (%s)", w.instance, startState), nil, nil)

	for {
		select {
		case <-cycle.C:
			w.run()

		case <-w.manualTrigger:
			w.logger.Info(ctx, "Worker triggered manually", map[string]any{
				"instance": w.instance,
			})
			w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s triggered manually", w.instance), nil, nil)
			w.run()

		case <-ctx.Done():
			w.logger.Info(ctx, "Worker shutting down", map[string]any{
				"instance": w.instance,
			})
			w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s shutting down", w.instance), nil, nil)
			w.status.IsRunning = false
			w.updateDatabaseStatus(ctx)
			return
		}
	}
}
// handleStartupPause requests the global pause through the worker service
// when the worker is configured to start paused, and does nothing otherwise.
// A failure to set the pause is logged, not returned.
func (w *Worker) handleStartupPause(ctx context.Context) {
	if !w.workerCfg.StartWorkerPaused {
		return
	}

	w.logger.Info(ctx, "Worker configured to start paused - setting global pause", map[string]any{
		"instance": w.instance,
	})

	pauseErr := w.workerService.SetGlobalPause(ctx, true)
	if pauseErr != nil {
		w.logger.Error(ctx, "Failed to set global pause on startup", pauseErr, map[string]any{
			"instance": w.instance,
		})
		return
	}

	w.logger.Info(ctx, "Global pause set on startup as configured", map[string]any{
		"instance": w.instance,
	})
}
// getInitialWorkerStatus returns the status label used in the startup log:
// "paused (globally)" when the global pause is set, "paused (instance)" when
// this instance's stored status is paused, and "running" otherwise.
//
// Lookup failures never change the answer away from "running": a failed
// global pause check is logged as an error, and a failed instance status
// lookup is logged at debug level.
func (w *Worker) getInitialWorkerStatus(ctx context.Context) string {
	globallyPaused, globalErr := w.workerService.IsGlobalPaused(ctx)
	if globalErr != nil {
		w.logger.Error(ctx, "Failed to check global pause status on startup", globalErr, map[string]any{
			"instance": w.instance,
		})
		return "running"
	}
	if globallyPaused {
		return "paused (globally)"
	}

	stored, statusErr := w.workerService.GetWorkerStatus(ctx, w.instance)
	if statusErr != nil {
		// A missing status row is the normal case for a brand-new worker.
		w.logger.Debug(ctx, "Worker status not found on startup (expected for new worker)", map[string]any{
			"instance": w.instance,
		})
		return "running"
	}
	if stored != nil && stored.IsPaused {
		return "paused (instance)"
	}
	return "running"
}
// heartbeatLoop refreshes the worker heartbeat once per heartbeat interval
// and returns when ctx is cancelled.
func (w *Worker) heartbeatLoop(ctx context.Context) {
	beat := time.NewTicker(config.WorkerHeartbeatInterval)
	defer beat.Stop()

	for {
		select {
		case <-beat.C:
			w.updateHeartbeat(ctx)
		case <-ctx.Done():
			return
		}
	}
}
// updateHeartbeat asks the worker service to record a heartbeat for this
// instance. A failure is logged, not returned.
func (w *Worker) updateHeartbeat(ctx context.Context) {
	beatErr := w.workerService.UpdateHeartbeat(ctx, w.instance)
	if beatErr == nil {
		return
	}
	w.logger.Error(ctx, "Failed to update heartbeat for worker", beatErr, map[string]any{
		"instance": w.instance,
	})
}
// run executes one worker cycle under a fresh background context.
//
// If the worker is paused the cycle ends immediately after recording the
// pause reason as the current activity. Otherwise question generation runs
// first, followed by the follow-up steps (daily question assignments, story
// generation, daily reminders, word-of-the-day assignments and emails,
// translation cache cleanup). A failing follow-up step is logged and does not
// stop the remaining steps.
//
// Only the outcome of question generation decides whether the run is recorded
// as a success or a failure.
func (w *Worker) run() {
	ctx, span := observability.TraceWorkerFunction(context.Background(), "run",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	// Publish the current status before looking at the pause flags.
	w.updateDatabaseStatus(ctx)
	if paused, reason := w.checkPauseStatus(ctx); paused {
		span.SetAttributes(attribute.String("pause_reason", reason))
		w.updateActivity(reason)
		return
	}

	w.status.LastRunStart = time.Now()
	w.updateDatabaseStatus(ctx)

	details, genErr := w.generateNeededQuestions(ctx)

	// Every follow-up step reports failure the same way: log and move on.
	logStepFailure := func(msg string, stepErr error) {
		if stepErr == nil {
			return
		}
		w.logger.Error(ctx, msg, stepErr, map[string]any{
			"instance": w.instance,
		})
	}

	// Daily questions are assigned independently of email reminders.
	logStepFailure("Failed to check daily question assignments", w.checkForDailyQuestionAssignments(ctx))
	// New story sections for users with active stories.
	logStepFailure("Failed to check story generations", w.checkForStoryGenerations(ctx))
	// Daily reminder emails.
	logStepFailure("Failed to check daily reminders", w.checkForDailyReminders(ctx))
	// Word of the day: assignment first, then the emails.
	logStepFailure("Failed to check word of the day assignments", w.checkForWordOfTheDayAssignments(ctx))
	logStepFailure("Failed to check word of the day emails", w.checkForWordOfTheDayEmails(ctx))
	// Expired translation cache entries.
	logStepFailure("Failed to cleanup translation cache", w.cleanupTranslationCache(ctx))

	w.status.LastRunFinish = time.Now()
	if genErr == nil {
		w.status.LastRunError = ""
	} else {
		w.status.LastRunError = genErr.Error()
		w.logger.Error(ctx, "Worker run failed", genErr, map[string]any{
			"instance": w.instance,
		})
	}

	w.recordRunHistory(details, genErr)
	w.updateDatabaseStatus(ctx)
}
// checkPauseStatus reports whether the worker must skip the current cycle,
// together with a short reason.
//
// A failed global pause lookup is treated as paused, so no work is done. A
// failed instance status lookup is treated as not paused, since the status
// row may not exist yet during startup.
func (w *Worker) checkPauseStatus(ctx context.Context) (bool, string) {
	globallyPaused, globalErr := w.workerService.IsGlobalPaused(ctx)
	switch {
	case globalErr != nil:
		w.logger.Error(ctx, "Failed to check global pause status", globalErr, map[string]any{
			"instance": w.instance,
		})
		return true, "Error checking global pause status"
	case globallyPaused:
		return true, "Globally paused"
	}

	stored, statusErr := w.workerService.GetWorkerStatus(ctx, w.instance)
	if statusErr != nil {
		w.logger.Debug(ctx, "Worker status not found during pause check (assuming not paused)", map[string]any{
			"instance": w.instance,
		})
		return false, ""
	}
	if stored != nil && stored.IsPaused {
		return true, "Worker instance paused"
	}
	return false, ""
}
// recordRunHistory appends a record of the run that just finished to the
// in-memory history, then drops the oldest records so that at most
// cfg.Server.MaxHistory remain. Start and finish times are taken from the
// current status; the record is a "Failure" when err is non-nil and a
// "Success" otherwise.
func (w *Worker) recordRunHistory(details string, err error) {
	started, finished := w.status.LastRunStart, w.status.LastRunFinish
	entry := RunRecord{
		StartTime: started,
		EndTime:   finished,
		Duration:  finished.Sub(started),
		Details:   details,
		Status:    "Success",
	}
	if err != nil {
		entry.Status = "Failure"
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	w.history = append(w.history, entry)
	// Keep only the newest MaxHistory records.
	if excess := len(w.history) - w.cfg.Server.MaxHistory; excess > 0 {
		w.history = w.history[excess:]
	}
}
// GetStatus returns a copy of the worker's current status, read under the
// worker's read lock.
func (w *Worker) GetStatus() Status {
	w.mu.RLock()
	snapshot := w.status
	w.mu.RUnlock()
	return snapshot
}
// GetHistory returns a copy of the run history, so callers can keep or modify
// the result without affecting the worker. The result is never nil.
func (w *Worker) GetHistory() []RunRecord {
	w.mu.RLock()
	defer w.mu.RUnlock()

	snapshot := make([]RunRecord, 0, len(w.history))
	return append(snapshot, w.history...)
}
// checkForStoryGenerations generates a new story section for every user that
// has an active story open for automatic generation.
//
// Per-user failures are logged and skipped; hitting the daily generation
// limit is logged at info level because it is an expected outcome for the
// worker. The only error returned is a failure to load the list of users.
func (w *Worker) checkForStoryGenerations(ctx context.Context) error {
	ctx, span := observability.TraceWorkerFunction(ctx, "check_story_generations",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	w.updateActivity("Checking for story generations...")

	storyUsers, listErr := w.getUsersWithActiveStories(ctx)
	if listErr != nil {
		return contextutils.WrapErrorf(listErr, "failed to get users with active stories")
	}

	w.logger.Info(ctx, "Found users with active stories", map[string]any{
		"count":    len(storyUsers),
		"instance": w.instance,
	})

	generated := 0
	for _, user := range storyUsers {
		genErr := w.generateStorySection(ctx, user)
		switch {
		case genErr == nil:
			generated++
		case errors.Is(genErr, contextutils.ErrGenerationLimitReached):
			// Reaching the limit is normal for the worker, not a failure.
			w.logger.Info(ctx, "User reached daily generation limit, skipping", map[string]any{
				"user_id":  user.ID,
				"username": user.Username,
				"instance": w.instance,
			})
		default:
			w.logger.Error(ctx, "Failed to generate story section for user", genErr, map[string]any{
				"user_id":  user.ID,
				"username": user.Username,
				"instance": w.instance,
			})
		}
	}

	w.updateActivity(fmt.Sprintf("Generated story sections for %d users", generated))
	w.logger.Info(ctx, "Story generation completed", map[string]any{
		"processed": generated,
		"total":     len(storyUsers),
		"instance":  w.instance,
	})
	return nil
}
// generateStorySection generates the next section of the user's current
// story. All service calls run under a context bounded by
// config.AIRequestTimeout so a stuck AI request cannot hang the worker.
//
// It returns nil when the user has no current story or has reached the daily
// generation limit, and a wrapped error for any other failure.
func (w *Worker) generateStorySection(ctx context.Context, user models.User) error {
	ctx, span := observability.TraceWorkerFunction(ctx, "generate_story_section",
		attribute.String("worker.instance", w.instance),
		attribute.String("user.username", user.Username),
		attribute.Int("user.id", int(user.ID)),
	)
	defer observability.FinishSpan(span, nil)

	// Same timeout as the other AI operations.
	boundedCtx, release := context.WithTimeout(ctx, config.AIRequestTimeout)
	defer release()

	current, lookupErr := w.storyService.GetCurrentStory(boundedCtx, uint(user.ID))
	switch {
	case lookupErr != nil:
		return contextutils.WrapErrorf(lookupErr, "failed to get current story for user %d", user.ID)
	case current == nil:
		// Nothing to extend.
		return nil
	}

	aiConfig, keyID := w.getUserAIConfig(boundedCtx, &user)

	// Usage tracking reads the user ID and API key ID from the context.
	boundedCtx = contextutils.WithUserID(boundedCtx, user.ID)
	if keyID != nil {
		boundedCtx = contextutils.WithAPIKeyID(boundedCtx, *keyID)
	}

	// Shared service method, tagged as a worker-initiated generation.
	_, genErr := w.storyService.GenerateStorySection(boundedCtx, current.ID, uint(user.ID), w.aiService, aiConfig, models.GeneratorTypeWorker)
	if genErr == nil {
		return nil
	}
	if errors.Is(genErr, contextutils.ErrGenerationLimitReached) {
		// Expected for the worker: skip this user without reporting an error.
		w.logger.Info(ctx, "User reached daily generation limit, skipping", map[string]any{
			"user_id":  user.ID,
			"story_id": current.ID,
		})
		return nil
	}
	return contextutils.WrapErrorf(genErr, "failed to generate story section")
}
// getUsersWithActiveStories loads all users and keeps those for whom the
// worker should auto-generate story sections: AI enabled, AI provider and
// model both set, a current story whose status is active, and
// auto-generation not paused on that story.
//
// A user whose current story cannot be loaded is skipped silently. It returns
// an error only when the user list itself cannot be loaded.
func (w *Worker) getUsersWithActiveStories(ctx context.Context) ([]models.User, error) {
	everyone, loadErr := w.userService.GetAllUsers(ctx)
	if loadErr != nil {
		return nil, contextutils.WrapErrorf(loadErr, "failed to get all users")
	}

	var filteredUsers []models.User
	for _, candidate := range everyone {
		// Cheap checks on the user record come first, so the story lookup is
		// only made for users that could qualify.
		aiOn := candidate.AIEnabled.Valid && candidate.AIEnabled.Bool
		aiConfigured := candidate.AIProvider.Valid && candidate.AIModel.Valid
		if !aiOn || !aiConfigured {
			continue
		}

		current, storyErr := w.storyService.GetCurrentStory(ctx, uint(candidate.ID))
		if storyErr != nil || current == nil {
			continue
		}
		if current.Status != models.StoryStatusActive {
			continue
		}
		if current.AutoGenerationPaused {
			w.logger.Debug(ctx, "Skipping story with auto-generation paused", map[string]any{
				"user_id":  candidate.ID,
				"story_id": current.ID,
			})
			continue
		}

		filteredUsers = append(filteredUsers, candidate)
	}
	return filteredUsers, nil
}
// GetActivityLogs returns a copy of the in-memory activity logs, so callers
// can keep or modify the result without affecting the worker. The result is
// never nil.
func (w *Worker) GetActivityLogs() []ActivityLog {
	w.mu.RLock()
	defer w.mu.RUnlock()

	snapshot := make([]ActivityLog, 0, len(w.activityLogs))
	return append(snapshot, w.activityLogs...)
}
// GetInstance returns the name of this worker instance, as stored on the
// worker.
func (w *Worker) GetInstance() string {
return w.instance
}
// GetEmailService returns the mailer this worker holds as its email service.
func (w *Worker) GetEmailService() mailer.Mailer {
return w.emailService
}
// TriggerManualRun asks the worker loop to run a cycle as soon as possible.
// It never blocks: the trigger channel holds a single pending request, and
// when one is already queued this call only logs that fact.
func (w *Worker) TriggerManualRun() {
	outcome := "Manual trigger already pending for worker"
	select {
	case w.manualTrigger <- true:
		outcome = "Manual trigger sent to worker"
	default:
		// A trigger is already queued; nothing to add.
	}

	w.logger.Info(context.Background(), outcome, map[string]any{
		"instance": w.instance,
	})
}
// Pause marks this worker instance as paused. The pause is requested from the
// worker service first; if that call fails, a warning is logged and the
// in-memory status is still set to paused and persisted.
func (w *Worker) Pause(ctx context.Context) {
	pauseErr := w.workerService.PauseWorker(ctx, w.instance)
	if pauseErr != nil {
		w.logger.Warn(ctx, "Failed to pause worker in service", map[string]any{
			"instance": w.instance,
			"error":    pauseErr.Error(),
		})
	}

	w.logger.Info(ctx, "Worker paused", map[string]any{
		"instance": w.instance,
	})
	w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s paused", w.instance), nil, nil)

	w.status.IsPaused = true
	w.updateDatabaseStatus(ctx)
}
// Resume clears the paused state of this worker instance. If the worker
// service rejects the resume, a warning is logged, the unchanged status is
// persisted again, and the worker stays paused.
func (w *Worker) Resume(ctx context.Context) {
	resumeErr := w.workerService.ResumeWorker(ctx, w.instance)
	if resumeErr == nil {
		w.logger.Info(ctx, "Worker resumed", map[string]any{
			"instance": w.instance,
		})
		w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s resumed", w.instance), nil, nil)
		w.status.IsPaused = false
		w.updateDatabaseStatus(ctx)
		return
	}

	// The resume did not go through: leave the paused flag untouched.
	w.logger.Warn(ctx, "Failed to resume worker in service", map[string]any{
		"instance": w.instance,
		"error":    resumeErr.Error(),
	})
	w.updateDatabaseStatus(ctx)
}
// Shutdown signals the worker to stop, gives in-flight work a grace period of
// config.WorkerSleepDuration, and then clears the in-memory failure tracking
// and activity logs. It always returns nil.
//
// The grace period ends early if ctx is cancelled or its deadline passes;
// cleanup still runs in that case.
//
// The worker mutex is only held for the short sections that touch guarded
// state. It is deliberately NOT held during the grace period: holding it
// there would block GetStatus, GetHistory and recordRunHistory for the whole
// wait, stalling the very operations the wait is meant to let finish.
func (w *Worker) Shutdown(ctx context.Context) error {
	w.logger.Info(ctx, "Worker starting shutdown", map[string]interface{}{
		"instance": w.instance,
	})

	// Cancel the worker's own context to signal shutdown.
	w.mu.RLock()
	cancel := w.cancel
	w.mu.RUnlock()
	if cancel != nil {
		cancel()
	}

	// Wait for any active operations to complete. This is a simple
	// time-based grace period; a more complex system might track active
	// operations precisely instead.
	grace := time.NewTimer(config.WorkerSleepDuration)
	defer grace.Stop()
	select {
	case <-grace.C:
	case <-ctx.Done():
		w.logger.Warn(ctx, "Worker shutdown grace period interrupted by context", map[string]interface{}{
			"instance": w.instance,
			"error":    ctx.Err().Error(),
		})
	}

	// Clean up user failures map.
	w.failureMu.Lock()
	w.userFailures = make(map[int]*UserFailureInfo)
	w.failureMu.Unlock()

	// Clear activity logs.
	w.mu.Lock()
	w.activityLogs = make([]ActivityLog, 0)
	w.mu.Unlock()

	w.logger.Info(ctx, "Worker shutdown completed", map[string]interface{}{
		"instance": w.instance,
	})
	return nil
}
// updateDatabaseStatus persists a snapshot of the worker's in-memory status
// through the worker service. The heartbeat timestamp is set to the current
// time; empty strings and zero times are stored as NULL. A failed write is
// logged, not returned.
func (w *Worker) updateDatabaseStatus(ctx context.Context) {
	current := w.status

	row := &models.WorkerStatus{
		WorkerInstance: w.instance,
		IsRunning:      current.IsRunning,
		IsPaused:       current.IsPaused,
		CurrentActivity: sql.NullString{
			String: current.CurrentActivity,
			Valid:  current.CurrentActivity != "",
		},
		LastHeartbeat: sql.NullTime{Time: time.Now(), Valid: true},
		LastRunStart: sql.NullTime{
			Time:  current.LastRunStart,
			Valid: !current.LastRunStart.IsZero(),
		},
		LastRunFinish: sql.NullTime{
			Time:  current.LastRunFinish,
			Valid: !current.LastRunFinish.IsZero(),
		},
		LastRunError: sql.NullString{
			String: current.LastRunError,
			Valid:  current.LastRunError != "",
		},
		TotalQuestionsGenerated: w.getTotalQuestionsGenerated(),
		TotalRuns:               len(w.history),
	}

	saveErr := w.workerService.UpdateWorkerStatus(ctx, w.instance, row)
	if saveErr == nil {
		return
	}
	w.logger.Error(ctx, "Failed to update worker status in database", saveErr, map[string]any{
		"instance": w.instance,
	})
}
// getTotalQuestionsGenerated returns the number of runs in the in-memory
// history whose status is "Success".
//
// Despite the name, this is a simplified stand-in: it counts successful runs,
// not questions. Reporting real question counts would require parsing each
// record's details.
func (w *Worker) getTotalQuestionsGenerated() int {
	successes := 0
	for i := range w.history {
		if w.history[i].Status != "Success" {
			continue
		}
		successes++
	}
	return successes
}
// generateNeededQuestions runs the question generation pass over all eligible
// AI users and returns a human-readable summary of what happened.
//
// The pass is skipped, with a descriptive summary and no error, when the
// worker is globally paused or there are no eligible users. A non-nil error
// is returned only when the global pause lookup or the user lookup fails;
// failures while processing individual users are folded into the summary.
func (w *Worker) generateNeededQuestions(ctx context.Context) (result0 string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "generate_needed_questions",
		attribute.String("worker.instance", w.instance),
	)
	// FinishSpan receives a pointer to the named result so it sees the final
	// error value.
	defer observability.FinishSpan(span, &err)

	// The global pause is checked before any other work or logging.
	pausedGlobally, err := w.workerService.IsGlobalPaused(ctx)
	if err != nil {
		span.RecordError(err)
		w.logger.Error(ctx, "Failed to check global pause status", err, map[string]any{
			"instance": w.instance,
		})
		return "Error checking global pause status", err
	}
	if pausedGlobally {
		span.SetAttributes(attribute.Bool("globally_paused", true))
		w.logger.Info(ctx, "Worker skipping question generation (globally paused)", map[string]any{
			"instance": w.instance,
		})
		return "Run paused globally", nil
	}

	candidates, err := w.getEligibleAIUsers(ctx)
	if err != nil {
		return "Error getting users", err
	}
	if len(candidates) == 0 {
		w.logger.Info(ctx, "Worker: No active users with AI provider configuration found for question generation", map[string]any{
			"instance": w.instance,
		})
		return "No active users with AI provider configuration found", nil
	}

	var (
		actions       []string // per-user action summaries
		checked       []string // every user that was looked at
		processed     []string // users that passed the should-process check
		anyAttempted  bool
		anyFailed     bool
	)
	for _, user := range candidates {
		checked = append(checked, user.Username)

		proceed, skipReason := w.shouldProcessUser(ctx, &user)
		if !proceed {
			if skipReason != "" {
				w.logger.Info(ctx, "Worker user check", map[string]any{
					"instance": w.instance,
					"username": user.Username,
					"reason":   skipReason,
				})
			}
			continue
		}
		processed = append(processed, user.Username)

		summary, attempted, failed := w.processUserQuestionGeneration(ctx, &user)
		anyAttempted = anyAttempted || attempted
		anyFailed = anyFailed || failed
		if summary != "" {
			actions = append(actions, summary)
		}

		w.logger.Info(ctx, "Worker completed check for user", map[string]any{
			"instance": w.instance,
			"username": user.Username,
		})
	}

	// Clear the current activity now that the pass is over.
	w.updateActivity("")
	return w.summarizeRunActions(actions, checked, processed, anyAttempted, anyFailed), nil
}
// getEligibleAIUsers loads all users and keeps those eligible for AI question
// generation: AI enabled, not individually paused, and a non-empty AI
// provider set. A failed per-user pause lookup counts as not paused.
//
// The saved API key for the provider is looked up as well, but the outcome
// does not change eligibility, because having a provider is already
// sufficient.
//
// It returns an error only when the user list itself cannot be loaded.
func (w *Worker) getEligibleAIUsers(ctx context.Context) (result0 []models.User, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_eligible_ai_users",
		attribute.String("worker.instance", w.instance),
	)
	// FinishSpan receives a pointer to the named result so it sees the final
	// error value.
	defer observability.FinishSpan(span, &err)

	allUsers, err := w.userService.GetAllUsers(ctx)
	if err != nil {
		span.RecordError(err)
		return nil, err
	}

	var selected []models.User
	for _, candidate := range allUsers {
		if !candidate.AIEnabled.Valid || !candidate.AIEnabled.Bool {
			continue
		}

		// Only a successful lookup that reports "paused" excludes the user.
		if paused, pauseErr := w.workerService.IsUserPaused(ctx, candidate.ID); pauseErr == nil && paused {
			continue
		}

		providerSet := candidate.AIProvider.Valid && candidate.AIProvider.String != ""
		keySaved := false
		if providerSet {
			key, keyErr := w.userService.GetUserAPIKey(ctx, candidate.ID, candidate.AIProvider.String)
			keySaved = keyErr == nil && key != ""
		}

		if keySaved || providerSet {
			selected = append(selected, candidate)
		}
	}
	return selected, nil
}
// shouldProcessUser reports whether question generation should run for user
// right now. When it returns false the second value is a human-readable
// reason.
//
// Checks, in order: per-user exponential backoff, the global pause flag, the
// pause flag of this worker instance, and cancellation of ctx. An error while
// reading the global pause flag blocks processing; an error while reading the
// instance status does not.
func (w *Worker) shouldProcessUser(ctx context.Context, user *models.User) (bool, string) {
	if !w.shouldRetryUser(user.ID) {
		// shouldRetryUser released its lock before returning, so the entry may
		// have been removed in the meantime. Copy what the message needs while
		// holding the lock and never touch the entry after unlocking.
		var (
			inBackoff    bool
			failureCount int
			retryIn      time.Duration
		)
		w.failureMu.RLock()
		if failure, tracked := w.userFailures[user.ID]; tracked && failure != nil {
			inBackoff = true
			failureCount = failure.ConsecutiveFailures
			retryIn = time.Until(failure.NextRetryTime)
		}
		w.failureMu.RUnlock()

		if inBackoff {
			return false, fmt.Sprintf("Skipping due to exponential backoff (failure #%d, retry in %v)", failureCount, retryIn.Round(time.Second))
		}
		// The failure record was cleared between the two checks: the user is
		// no longer backing off, so continue with the remaining checks.
	}

	globalPaused, err := w.workerService.IsGlobalPaused(ctx)
	if err != nil {
		return false, "Error checking global pause status"
	}
	if globalPaused {
		return false, "Run paused globally"
	}

	// A failed status lookup is deliberately not treated as "paused".
	status, err := w.workerService.GetWorkerStatus(ctx, w.instance)
	if err == nil && status != nil && status.IsPaused {
		return false, fmt.Sprintf("Worker instance %s paused", w.instance)
	}

	if ctx.Err() != nil {
		return false, "Shutdown initiated"
	}
	return true, ""
}
// getEligibleQuestionCount counts the active questions of the given language,
// level and type that are assigned to the user, leaving out questions the user
// answered correctly inside the 2-day window computed by
// contextutils.UserLocalDayRange.
//
// The count is read straight from the database, which is only reachable when
// the question service is the concrete *services.QuestionService; any other
// implementation (for example a test mock) yields an error.
func (w *Worker) getEligibleQuestionCount(ctx context.Context, userID int, language, level string, qType models.QuestionType) (result0 int, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_eligible_question_count",
		observability.AttributeUserID(userID),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	// Only the concrete user service is consulted so that unit-test mocks are
	// never invoked; returning (nil, nil) lets the range helper fall back to UTC.
	lookupUser := func(ctx context.Context, id int) (*models.User, error) {
		concrete, isConcrete := w.userService.(*services.UserService)
		if !isConcrete || concrete == nil {
			return nil, nil
		}
		return concrete.GetUserByID(ctx, id)
	}

	// The window is computed in the user's local time and passed to the query
	// as UTC bounds.
	windowStart, windowEnd, _, err := contextutils.UserLocalDayRange(ctx, userID, 2, lookupUser)
	if err != nil {
		return 0, contextutils.WrapError(err, "failed to compute user local day range")
	}

	countQuery := `
SELECT COUNT(*)
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
AND q.status = 'active'
AND NOT EXISTS (
SELECT 1 FROM user_responses ur
WHERE ur.user_id = $1
AND ur.question_id = q.id
AND ur.is_correct = TRUE
AND ur.created_at >= $5 AND ur.created_at < $6
)
`

	concreteQS, isConcrete := w.questionService.(*services.QuestionService)
	if !isConcrete {
		// Expected for mocks and other implementations: there is no database
		// handle to query.
		return 0, contextutils.ErrorWithContextf("cannot get database from question service implementation")
	}
	var db *sql.DB = concreteQS.DB()

	var eligible int
	scanErr := db.QueryRowContext(ctx, countQuery, userID, language, level, qType, windowStart, windowEnd).Scan(&eligible)
	if scanErr != nil {
		return 0, scanErr
	}
	return eligible, nil
}
// processUserQuestionGeneration checks every question type for the user's
// language and level and requests a new batch from the AI service wherever the
// number of eligible questions is below the configured refill threshold.
//
// Question types that have an active generation hint are visited first and are
// always treated as needing a refill. Each hinted type is visited once, even
// when several active hints name the same type.
//
// It returns the "; "-joined summaries of the successful generations, whether
// any generation was attempted, and whether any step failed.
func (w *Worker) processUserQuestionGeneration(ctx context.Context, user *models.User) (string, bool, bool) {
	ctx, span := observability.TraceWorkerFunction(ctx, "processUserQuestionGeneration",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	// Fall back to italian / A1 when the user has no stored preference.
	userLanguage := "italian"
	if user.PreferredLanguage.Valid && user.PreferredLanguage.String != "" {
		userLanguage = user.PreferredLanguage.String
		span.SetAttributes(attribute.String("user.language", userLanguage))
	}
	userLevel := "A1"
	if user.CurrentLevel.Valid && user.CurrentLevel.String != "" {
		userLevel = user.CurrentLevel.String
		span.SetAttributes(attribute.String("user.level", userLevel))
	}

	languages := []string{userLanguage}
	levels := []string{userLevel}
	questionTypes := []models.QuestionType{
		models.Vocabulary,
		models.FillInBlank,
		models.QuestionAnswer,
		models.ReadingComprehension,
	}

	// Reorder types based on active generation hints (hinted types first,
	// stable order).
	if w.hintService != nil {
		if hints, err := w.hintService.GetActiveHintsForUser(ctx, user.ID); err == nil && len(hints) > 0 {
			hinted := make([]models.QuestionType, 0, len(hints))
			hintedSet := make(map[models.QuestionType]bool, len(hints))
			for _, h := range hints {
				qt := models.QuestionType(h.QuestionType)
				if hintedSet[qt] {
					// Several hints can name the same type; listing it again
					// would make the loop below process that type twice.
					continue
				}
				hintedSet[qt] = true
				hinted = append(hinted, qt)
			}
			rest := make([]models.QuestionType, 0, len(questionTypes))
			for _, qt := range questionTypes {
				if !hintedSet[qt] {
					rest = append(rest, qt)
				}
			}
			questionTypes = append(hinted, rest...)
		}
	}

	var actions []string
	var hadAttemptedOperations bool
	var hadFailures bool

	for _, language := range languages {
		for _, level := range levels {
			for _, qType := range questionTypes {
				activity := fmt.Sprintf("Checking questions for user %s: %s %s %s", user.Username, language, level, qType)
				w.updateActivity(activity)

				// Use eligible question count (not just total assigned).
				eligibleCount, err := w.getEligibleQuestionCount(ctx, user.ID, language, level, qType)
				if err != nil {
					span.RecordError(err)
					hadFailures = true
					continue // Continue to next question type
				}

				// If hinted, be more aggressive about generating for that
				// type. Hints are re-read on every iteration, so a hint that
				// was cleared earlier in this run is not honoured again.
				hinted := false
				if w.hintService != nil {
					if hints, err := w.hintService.GetActiveHintsForUser(ctx, user.ID); err == nil {
						for _, h := range hints {
							if models.QuestionType(h.QuestionType) == qType {
								hinted = true
								break
							}
						}
					}
				}

				refillThreshold := w.cfg.Server.QuestionRefillThreshold
				if hinted {
					// Treat as if pool is empty to trigger generation, but
					// keep batch sizing logic.
					eligibleCount = 0
				}
				if eligibleCount >= refillThreshold {
					continue
				}

				provider := "default"
				if user.AIProvider.Valid && user.AIProvider.String != "" {
					provider = user.AIProvider.String
				}

				// Base batch size from AI provider.
				needed := w.aiService.GetQuestionBatchSize(provider)

				// Get user's learning preferences to use their personal
				// FreshQuestionRatio.
				userPrefs, prefsErr := w.learningService.GetUserLearningPreferences(ctx, user.ID)
				userFreshRatio := 0.7 // default fallback
				if prefsErr == nil && userPrefs != nil && userPrefs.FreshQuestionRatio > 0 {
					userFreshRatio = userPrefs.FreshQuestionRatio
				} else if prefsErr != nil {
					w.logger.Warn(ctx, "Failed to get user learning preferences, using default fresh ratio", map[string]interface{}{
						"user_id": user.ID,
						"error":   prefsErr.Error(),
					})
				}

				// Ensure at least enough fresh questions are available to meet
				// the user's personal FreshQuestionRatio, so daily question
				// assignment can respect the user's freshness preference.
				desiredFresh := int(math.Ceil(float64(refillThreshold) * userFreshRatio))
				freshCandidates := 0
				if qs, qerr := w.questionService.GetAdaptiveQuestionsForDaily(ctx, user.ID, language, level, 50); qerr == nil && qs != nil {
					// A candidate counts as fresh when it has no responses yet.
					for _, q := range qs {
						if q != nil && q.TotalResponses == 0 {
							freshCandidates++
						}
					}
				} else if qerr != nil {
					// Log but don't fail - we'll conservatively proceed with
					// base batch size.
					w.logger.Warn(ctx, "Failed to fetch adaptive questions for fresh-count check", map[string]interface{}{
						"user_id": user.ID,
						"error":   qerr.Error(),
					})
				}

				if missing := desiredFresh - freshCandidates; missing > 0 {
					needed += missing
					w.logger.Info(ctx, "Adjusting generation batch to meet user's personal fresh-question requirement", map[string]interface{}{
						"user_id":          user.ID,
						"language":         language,
						"level":            level,
						"question_type":    qType,
						"user_fresh_ratio": userFreshRatio,
						"base_batch_size":  w.aiService.GetQuestionBatchSize(provider),
						"desired_fresh":    desiredFresh,
						"fresh_candidates": freshCandidates,
						"added_to_batch":   missing,
						"final_batch_size": needed,
					})
				}

				hadAttemptedOperations = true
				action, err := w.GenerateQuestionsForUser(ctx, user, language, level, qType, needed, "")
				if err != nil {
					hadFailures = true
					// Continue to next question type instead of breaking all
					// loops.
					continue
				}
				if action != "" {
					actions = append(actions, action)
				}

				// Clear hint on successful generation attempt for this type.
				// Best effort: a failure to clear must not fail the run.
				if hinted && w.hintService != nil {
					_ = w.hintService.ClearHint(ctx, user.ID, language, level, qType)
				}
			}
		}
	}

	return strings.Join(actions, "; "), hadAttemptedOperations, hadFailures
}
// summarizeRunActions renders the outcome of one worker run as text.
//
// With at least one action, the result is the actions one per line followed by
// a "Processed users" line. Without actions, it is a single line explaining
// why nothing happened (nobody was processed, every attempt failed, or nothing
// needed refilling) followed by the list of checked users.
func (w *Worker) summarizeRunActions(actions, checkedUsers, actuallyProcessedUsers []string, hadAttemptedOperations, hadFailures bool) string {
	if len(actions) > 0 {
		// Line breaks keep the actions readable in the UI.
		processed := fmt.Sprintf("Processed users: %s", strings.Join(actuallyProcessedUsers, ", "))
		return fmt.Sprintf("%s\n%s", strings.Join(actions, "\n"), processed)
	}

	checked := "No users with AI configuration found"
	if len(checkedUsers) > 0 {
		checked = fmt.Sprintf("Checked users: %s", strings.Join(checkedUsers, ", "))
	}

	switch {
	case len(actuallyProcessedUsers) == 0:
		return fmt.Sprintf("No actions taken. All users in exponential backoff. %s", checked)
	case hadAttemptedOperations && hadFailures:
		return fmt.Sprintf("No actions taken due to errors. %s", checked)
	default:
		return fmt.Sprintf("No actions taken. All question types have sufficient questions. %s", checked)
	}
}
// GenerateQuestionsForUser asks the AI service for count questions of the
// given type, language and level, saves them and assigns them to user.
//
// The returned string is a short summary of what was generated. An error is
// returned when the AI request cannot be built, when streaming fails, when the
// AI service yields no questions, or when fewer questions were saved than were
// generated. A count of zero or less is a no-op.
func (w *Worker) GenerateQuestionsForUser(ctx context.Context, user *models.User, language, level string, qType models.QuestionType, count int, topic string) (result0 string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "generate_questions_for_user",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Int("question.count", count),
		attribute.String("topic", topic),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	if count <= 0 {
		return "No questions needed", nil
	}

	// Collect the priority inputs once; every value stays at its zero value
	// when no priority data is available.
	priority := w.getPriorityGenerationData(ctx, user.ID, language, level, qType)
	var (
		focusWeak      bool
		weakAreas      []string
		priorityTopics []string
		gaps           map[string]int
		freshRatio     float64
	)
	if priority != nil {
		focusWeak = priority.FocusOnWeakAreas
		if focusWeak {
			// Weak areas only steer variety when the focus is switched on.
			weakAreas = priority.UserWeakAreas
		}
		priorityTopics = priority.HighPriorityTopics
		gaps = priority.GapAnalysis
		freshRatio = priority.FreshQuestionRatio
	}

	variety := w.aiService.VarietyService().SelectVarietyElements(ctx, level, priorityTopics, weakAreas, gaps)

	// Log the priority decisions behind this generation.
	w.logPriorityGeneration(ctx, PriorityGenerationLog{
		UserID:              user.ID,
		Username:            user.Username,
		Language:            language,
		Level:               level,
		QuestionType:        string(qType),
		FocusOnWeakAreas:    focusWeak,
		UserWeakAreas:       weakAreas,
		HighPriorityTopics:  priorityTopics,
		GapAnalysis:         gaps,
		FreshQuestionRatio:  freshRatio,
		SelectedVariety:     variety,
		GenerationReasoning: w.getGenerationReasoning(priority, variety),
		Timestamp:           time.Now(),
	})

	aiReq, recentQuestions, err := w.buildAIQuestionGenRequest(ctx, user, language, level, qType, count, topic)
	if err != nil {
		w.logger.Warn(ctx, "Worker failed to get recent questions", map[string]interface{}{
			"instance": w.instance,
			"error":    err.Error(),
		})
		return "", contextutils.WrapError(err, "failed to build AI request")
	}
	aiReq.RecentQuestionHistory = recentQuestions

	userConfig, apiKeyID := w.getUserAIConfig(ctx, user)

	// Announce the batch in the structured log, the live status and the
	// activity log.
	batchLogMsg := formatBatchLogMessage(user.Username, count, string(qType), language, level, variety, userConfig.Provider, userConfig.Model)
	w.logger.Info(ctx, batchLogMsg, map[string]interface{}{
		"instance": w.instance,
	})
	w.updateActivity(batchLogMsg)
	w.logActivity(ctx, "INFO", batchLogMsg, &user.ID, &user.Username)

	progressMsg, questions, streamErr := w.handleAIQuestionStream(ctx, userConfig, apiKeyID, aiReq, variety, count, language, level, qType, topic, user)
	switch {
	case streamErr != nil:
		w.recordUserFailure(ctx, user.ID, user.Username)
		return progressMsg, streamErr
	case len(questions) == 0:
		w.recordUserFailure(ctx, user.ID, user.Username)
		return progressMsg, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "AI service returned 0 questions for %s %s %s", language, level, qType)
	}

	savedCount := w.saveGeneratedQuestions(ctx, user, questions, language, level, qType, topic, variety)
	if savedCount > 0 {
		w.recordUserSuccess(ctx, user.ID, user.Username)
	}
	if savedCount == len(questions) {
		return fmt.Sprintf("Generated %d %s questions for %s %s", savedCount, qType, language, level), nil
	}

	// NOTE(review): a partial save records a success above and a failure here,
	// which leaves the user's failure count at 1 — confirm this is intended.
	w.recordUserFailure(ctx, user.ID, user.Username)
	return fmt.Sprintf("Generated %d/%d %s questions for %s %s", savedCount, len(questions), qType, language, level),
		contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "only saved %d out of %d generated questions", savedCount, len(questions))
}
// buildAIQuestionGenRequest assembles the AI generation request for the given
// language, level, type and count, and loads the contents of the user's 10
// most recent questions.
//
// The recent questions are returned separately and are also stored on the
// request. The unnamed final string parameter is unused. A failure to load
// the recent questions is recorded on the span and returned.
func (w *Worker) buildAIQuestionGenRequest(ctx context.Context, user *models.User, language, level string, qType models.QuestionType, count int, _ string) (result0 *models.AIQuestionGenRequest, result1 []string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "build_ai_question_gen_request",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Int("question.count", count),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	history, err := w.questionService.GetRecentQuestionContentsForUser(ctx, user.ID, 10)
	if err != nil {
		span.RecordError(err)
		return nil, nil, err
	}

	request := &models.AIQuestionGenRequest{
		Language:              language,
		Level:                 level,
		QuestionType:          qType,
		Count:                 count,
		RecentQuestionHistory: history,
	}
	return request, history, nil
}
// getUserAIConfig builds the AI configuration for user from their stored
// provider and model, and returns it together with the ID of the API key that
// was found for that provider.
//
// The API key and its ID are only filled in when a provider is set and the key
// lookup succeeds with a non-empty key; otherwise the key is empty and the ID
// is nil.
func (w *Worker) getUserAIConfig(ctx context.Context, user *models.User) (*models.UserAIConfig, *int) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_user_ai_config",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	config := &models.UserAIConfig{Username: user.Username}
	if user.AIProvider.Valid {
		config.Provider = user.AIProvider.String
		span.SetAttributes(attribute.String("ai.provider", config.Provider))
	}
	if user.AIModel.Valid {
		config.Model = user.AIModel.String
		span.SetAttributes(attribute.String("ai.model", config.Model))
	}

	// Without a provider there is no key to look up.
	if config.Provider == "" {
		return config, nil
	}

	var apiKeyID *int
	storedKey, storedKeyID, lookupErr := w.userService.GetUserAPIKeyWithID(ctx, user.ID, config.Provider)
	if lookupErr == nil && storedKey != "" {
		config.APIKey = storedKey
		apiKeyID = storedKeyID
	}
	return config, apiKeyID
}
// handleAIQuestionStream runs the AI question stream in a goroutine and
// collects every question it delivers.
//
// It returns the last progress message, the collected questions, and the error
// reported by the AI service. A panic inside the streaming goroutine is
// recovered, logged, and reported as an error so that callers do not mistake a
// crashed stream for a successful one.
//
// NOTE(review): the receive loop ends only when progressChan is closed; this
// function never closes it, so GenerateQuestionsStream is presumably
// responsible for that — confirm.
func (w *Worker) handleAIQuestionStream(ctx context.Context, userConfig *models.UserAIConfig, apiKeyID *int, req *models.AIQuestionGenRequest, variety *services.VarietyElements, count int, language, level string, qType models.QuestionType, topic string, user *models.User) (result0 string, result1 []*models.Question, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "handle_ai_question_stream",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Int("question.count", count),
		attribute.String("topic", topic),
		attribute.String("user.username", user.Username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	// Add user ID and API key ID to context for usage tracking.
	ctx = contextutils.WithUserID(ctx, user.ID)
	if apiKeyID != nil {
		ctx = contextutils.WithAPIKeyID(ctx, *apiKeyID)
	}

	progressChan := make(chan *models.Question)
	var questions []*models.Question
	var wg sync.WaitGroup
	// errAI is written only by the goroutine and read only after wg.Wait().
	var errAI error
	progressMsg := ""

	wg.Add(1)
	go func() {
		defer func() {
			if r := recover(); r != nil {
				w.logger.Error(ctx, "Panic in AI question stream goroutine", nil, map[string]interface{}{
					"instance": w.instance,
					"panic":    fmt.Sprintf("%v", r),
				})
				// Surface the panic to the caller instead of returning a nil
				// error alongside a possibly truncated result.
				errAI = fmt.Errorf("panic in AI question stream: %v", r)
			}
			wg.Done()
		}()
		errAI = w.aiService.GenerateQuestionsStream(ctx, userConfig, req, progressChan, variety)
	}()

	generatedCount := 0
	for question := range progressChan {
		generatedCount++
		progressMsg = fmt.Sprintf("Generated %d/%d %s questions for %s %s", generatedCount, count, qType, language, level)
		if topic != "" {
			progressMsg = fmt.Sprintf("Generated %d/%d %s questions for %s %s (topic: %s)", generatedCount, count, qType, language, level, topic)
		}
		w.logger.Info(ctx, progressMsg, map[string]interface{}{
			"instance": w.instance,
		})
		w.updateActivity(progressMsg)
		w.logActivity(ctx, "INFO", progressMsg, &user.ID, &user.Username)
		questions = append(questions, question)
	}

	// Wait for the goroutine so that errAI is final before it is read.
	wg.Wait()
	return progressMsg, questions, errAI
}
// saveGeneratedQuestions stores each generated question and assigns it to
// user, returning how many questions were both saved and assigned.
//
// When variety is non-nil its fields are copied onto every question before it
// is saved. A failure on one question is logged and does not stop the others.
func (w *Worker) saveGeneratedQuestions(ctx context.Context, user *models.User, questions []*models.Question, language, level string, qType models.QuestionType, topic string, variety *services.VarietyElements) int {
	ctx, span := observability.TraceWorkerFunction(ctx, "save_generated_questions",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Int("question.count", len(questions)),
		attribute.String("topic", topic),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	var startMsg string
	if topic == "" {
		startMsg = fmt.Sprintf("Saving %d new %s questions for %s %s", len(questions), qType, language, level)
	} else {
		startMsg = fmt.Sprintf("Saving %d new %s questions for %s %s (topic: %s)", len(questions), qType, language, level, topic)
	}
	w.logger.Info(ctx, startMsg, map[string]interface{}{
		"instance": w.instance,
	})
	w.updateActivity(startMsg)
	w.logActivity(ctx, "INFO", startMsg, &user.ID, &user.Username)

	stored := 0
	for _, question := range questions {
		// Stamp the question with the variety elements used during generation.
		if variety != nil {
			question.TopicCategory = variety.TopicCategory
			question.GrammarFocus = variety.GrammarFocus
			question.VocabularyDomain = variety.VocabularyDomain
			question.Scenario = variety.Scenario
			question.StyleModifier = variety.StyleModifier
			question.DifficultyModifier = variety.DifficultyModifier
			question.TimeContext = variety.TimeContext
		}

		if saveErr := w.questionService.SaveQuestion(ctx, question); saveErr != nil {
			w.logger.Error(ctx, "Failed to save generated question", saveErr, map[string]interface{}{
				"instance":      w.instance,
				"user_id":       user.ID,
				"language":      language,
				"level":         level,
				"question_type": qType,
			})
			continue
		}

		// A question only counts once it is also assigned to the user.
		if assignErr := w.questionService.AssignQuestionToUser(ctx, question.ID, user.ID); assignErr != nil {
			w.logger.Error(ctx, "Failed to assign question to user", assignErr, map[string]interface{}{
				"instance":    w.instance,
				"question_id": question.ID,
				"user_id":     user.ID,
			})
			continue
		}
		stored++
	}

	if stored == 0 {
		return stored
	}

	var doneMsg string
	if topic == "" {
		doneMsg = fmt.Sprintf("Successfully saved %d new '%s' questions for %s %s", stored, qType, language, level)
	} else {
		doneMsg = fmt.Sprintf("Successfully saved %d new '%s' questions for %s %s (topic: %s)", stored, qType, language, level, topic)
	}
	w.logActivity(ctx, "INFO", doneMsg, &user.ID, &user.Username)
	return stored
}
// updateActivity replaces the worker's current-activity text while holding
// w.mu. An empty string clears it.
func (w *Worker) updateActivity(activity string) {
	w.mu.Lock()
	defer w.mu.Unlock()

	w.status.CurrentActivity = activity
}
// logActivity appends an entry to the worker's in-memory activity log and
// trims the log to the most recent cfg.Server.MaxActivityLogs entries.
//
// level is stored on the entry as given; an empty level is recorded as
// "INFO". The context parameter is unused. userID and username may be nil.
func (w *Worker) logActivity(_ context.Context, level, message string, userID *int, username *string) {
	// Previously the level argument was discarded and every entry was stored
	// as "INFO"; keep that as the default only.
	if level == "" {
		level = "INFO"
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	w.activityLogs = append(w.activityLogs, ActivityLog{
		Timestamp: time.Now(),
		Level:     level,
		Message:   message,
		UserID:    userID,
		Username:  username,
	})

	// Keep only the newest entries once the configured limit is exceeded.
	if limit := w.cfg.Server.MaxActivityLogs; len(w.activityLogs) > limit {
		w.activityLogs = w.activityLogs[len(w.activityLogs)-limit:]
	}
}
// shouldRetryUser reports whether the user may be processed again: true when
// no failure is recorded for userID, or when the recorded next-retry time has
// already passed.
func (w *Worker) shouldRetryUser(userID int) bool {
	w.failureMu.RLock()
	defer w.failureMu.RUnlock()

	if failure, tracked := w.userFailures[userID]; tracked {
		return time.Now().After(failure.NextRetryTime)
	}
	// No previous failures, go ahead.
	return true
}
// recordUserFailure registers one more consecutive failure for the user and
// schedules the next retry with exponential backoff: 2^failures seconds,
// capped at one hour.
//
// The failure count and the chosen backoff are attached to the span and
// written to the log.
func (w *Worker) recordUserFailure(ctx context.Context, userID int, username string) {
	ctx, span := observability.TraceWorkerFunction(ctx, "record_user_failure",
		observability.AttributeUserID(userID),
		attribute.String("user.username", username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	const (
		maxBackoffSeconds = 3600
		// 2^12 = 4096 is the first power of two above the cap, so any count
		// from 12 upwards uses the cap directly.
		capFromFailures = 12
	)

	w.failureMu.Lock()
	defer w.failureMu.Unlock()

	failure, exists := w.userFailures[userID]
	if !exists {
		failure = &UserFailureInfo{}
		w.userFailures[userID] = failure
	}
	failure.ConsecutiveFailures++
	failure.LastFailureTime = time.Now()

	// Use a bounded integer shift instead of converting math.Pow's result:
	// for large counts the float exceeds the int range, the conversion result
	// is implementation-dependent (it can be negative), and the cap would then
	// be skipped, scheduling the retry in the past.
	backoffSeconds := maxBackoffSeconds
	if n := failure.ConsecutiveFailures; n >= 0 && n < capFromFailures {
		backoffSeconds = 1 << n
	}
	failure.NextRetryTime = time.Now().Add(time.Duration(backoffSeconds) * time.Second)

	span.SetAttributes(
		attribute.Int("failure.count", failure.ConsecutiveFailures),
		attribute.Int("backoff.seconds", backoffSeconds),
	)
	w.logger.Info(ctx, "Worker recorded user failure", map[string]interface{}{
		"instance":           w.instance,
		"username":           username,
		"failure_count":      failure.ConsecutiveFailures,
		"next_retry_seconds": backoffSeconds,
	})
}
// recordUserSuccess removes the user's failure record, which ends any backoff
// in progress. It does nothing when no failures are recorded for the user.
func (w *Worker) recordUserSuccess(ctx context.Context, userID int, username string) {
	ctx, span := observability.TraceWorkerFunction(ctx, "record_user_success",
		observability.AttributeUserID(userID),
		attribute.String("user.username", username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	w.failureMu.Lock()
	defer w.failureMu.Unlock()

	failure, tracked := w.userFailures[userID]
	if !tracked || failure.ConsecutiveFailures <= 0 {
		return
	}

	previous := failure.ConsecutiveFailures
	span.SetAttributes(attribute.Int("previous_failures", previous))
	w.logger.Info(ctx, "Worker user success after failures, resetting backoff", map[string]interface{}{
		"instance":          w.instance,
		"username":          username,
		"previous_failures": previous,
	})
	delete(w.userFailures, userID)
}
// formatBatchLogMessage builds the one-line description of a generation batch:
// user, batch size, question type, language and level, followed by the
// non-empty variety elements and finally the provider and model, all separated
// by " | ". A nil variety contributes nothing.
func formatBatchLogMessage(username string, count int, qType, language, level string, variety *services.VarietyElements, provider, model string) string {
	var segments []string
	if variety != nil {
		// Listed in the order in which they appear in the message.
		labelled := []struct{ label, value string }{
			{"grammar: ", variety.GrammarFocus},
			{"topic: ", variety.TopicCategory},
			{"scenario: ", variety.Scenario},
			{"style: ", variety.StyleModifier},
			{"difficulty: ", variety.DifficultyModifier},
			{"vocab: ", variety.VocabularyDomain},
			{"time: ", variety.TimeContext},
		}
		for _, field := range labelled {
			if field.value != "" {
				segments = append(segments, field.label+field.value)
			}
		}
	}
	// Provider and model always close the summary.
	segments = append(segments, "provider: "+provider+", model: "+model)

	return fmt.Sprintf("Worker [user=%s]: Batch %d %s questions (lang: %s, level: %s) | %s", username, count, qType, language, level, strings.Join(segments, " | "))
}
// PriorityGenerationData contains priority information to guide AI question generation.
// It is assembled by getPriorityGenerationData; lookups that fail there are
// replaced by empty values rather than aborting.
type PriorityGenerationData struct {
// UserWeakAreas holds the non-empty "topic" values of the user's weak-area records.
UserWeakAreas []string `json:"user_weak_areas,omitempty"`
// HighPriorityTopics is the result of the worker service's high-priority topic lookup.
HighPriorityTopics []string `json:"high_priority_topics,omitempty"`
// GapAnalysis is the result of the worker service's gap analysis; it is
// rendered as "topic(count)" pairs in the generation reasoning.
GapAnalysis map[string]int `json:"gap_analysis,omitempty"`
// UserPreferences is the learning-preferences value used for this generation.
UserPreferences *models.UserLearningPreferences `json:"user_preferences,omitempty"`
// PriorityDistribution is the result of the worker service's priority distribution lookup.
PriorityDistribution map[string]int `json:"priority_distribution,omitempty"`
// FocusOnWeakAreas is true only when at least one weak-area topic exists and
// the preferences enable the weak-area focus.
FocusOnWeakAreas bool `json:"focus_on_weak_areas"`
// FreshQuestionRatio is copied from the preferences' FreshQuestionRatio.
FreshQuestionRatio float64 `json:"fresh_question_ratio"`
}
// PriorityGenerationLog contains structured data about priority-aware generation decisions.
// GenerateQuestionsForUser fills one per generation request and
// logPriorityGeneration writes it to the log as JSON.
type PriorityGenerationLog struct {
// UserID and Username identify the user the questions are generated for.
UserID int `json:"user_id"`
Username string `json:"username"`
// Language, Level and QuestionType describe the requested batch.
Language string `json:"language"`
Level string `json:"level"`
QuestionType string `json:"question_type"`
// FocusOnWeakAreas mirrors the flag of the priority data (false when no
// priority data was available).
FocusOnWeakAreas bool `json:"focus_on_weak_areas"`
// UserWeakAreas is only populated when FocusOnWeakAreas is set.
UserWeakAreas []string `json:"user_weak_areas,omitempty"`
HighPriorityTopics []string `json:"high_priority_topics,omitempty"`
GapAnalysis map[string]int `json:"gap_analysis,omitempty"`
FreshQuestionRatio float64 `json:"fresh_question_ratio"`
// SelectedVariety is the set of variety elements chosen for this batch.
SelectedVariety *services.VarietyElements `json:"selected_variety"`
// GenerationReasoning is the human-readable text from getGenerationReasoning.
GenerationReasoning string `json:"generation_reasoning"`
// Timestamp is the time at which the log entry was created.
Timestamp time.Time `json:"timestamp"`
}
// logPriorityGeneration writes priorityLog to the worker log, serialized as a
// JSON string under the "data" key. If serialization fails, the error is
// recorded on the span and logged, and nothing else is written.
func (w *Worker) logPriorityGeneration(ctx context.Context, priorityLog PriorityGenerationLog) {
	ctx, span := observability.TraceWorkerFunction(ctx, "log_priority_generation",
		observability.AttributeUserID(priorityLog.UserID),
		attribute.String("user.username", priorityLog.Username),
		attribute.String("language", priorityLog.Language),
		attribute.String("level", priorityLog.Level),
		attribute.String("question.type", priorityLog.QuestionType),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	payload, marshalErr := json.Marshal(priorityLog)
	if marshalErr != nil {
		span.RecordError(marshalErr)
		w.logger.Error(ctx, "Failed to marshal priority generation log", marshalErr, map[string]interface{}{
			"instance": w.instance,
		})
		return
	}

	fields := map[string]interface{}{
		"instance": w.instance,
		"data":     string(payload),
	}
	w.logger.Info(ctx, "Worker priority generation", fields)
}
// getGenerationReasoning describes, in one human-readable line, what drives a
// generation: weak areas, high-priority topics, gap analysis, fresh-question
// ratio and the selected variety elements, joined by "; ".
//
// It returns "standard generation" when priorityData is nil or when none of
// those inputs contributes anything. Gap-analysis entries follow map iteration
// order, so their order is not stable between calls.
func (w *Worker) getGenerationReasoning(priorityData *PriorityGenerationData, variety *services.VarietyElements) string {
	const fallback = "standard generation"
	if priorityData == nil {
		return fallback
	}

	var parts []string

	if weak := priorityData.UserWeakAreas; priorityData.FocusOnWeakAreas && len(weak) > 0 {
		parts = append(parts, fmt.Sprintf("focusing on weak areas: %s", strings.Join(weak, ", ")))
	}

	if topics := priorityData.HighPriorityTopics; len(topics) > 0 {
		parts = append(parts, fmt.Sprintf("high priority topics: %s", strings.Join(topics, ", ")))
	}

	if len(priorityData.GapAnalysis) > 0 {
		var gapEntries []string
		for gapTopic, gapCount := range priorityData.GapAnalysis {
			gapEntries = append(gapEntries, fmt.Sprintf("%s(%d)", gapTopic, gapCount))
		}
		parts = append(parts, fmt.Sprintf("gap analysis: %s", strings.Join(gapEntries, ", ")))
	}

	if ratio := priorityData.FreshQuestionRatio; ratio > 0 {
		parts = append(parts, fmt.Sprintf("fresh ratio: %.1f%%", ratio*100))
	}

	if variety != nil {
		var chosen []string
		// Appends the formatted element only when it has a value.
		include := func(format, value string) {
			if value != "" {
				chosen = append(chosen, fmt.Sprintf(format, value))
			}
		}
		include("topic:%s", variety.TopicCategory)
		include("grammar:%s", variety.GrammarFocus)
		include("vocab:%s", variety.VocabularyDomain)
		include("scenario:%s", variety.Scenario)
		if len(chosen) > 0 {
			parts = append(parts, fmt.Sprintf("variety: %s", strings.Join(chosen, ", ")))
		}
	}

	if len(parts) == 0 {
		return fallback
	}
	return strings.Join(parts, "; ")
}
// getPriorityGenerationData gathers the priority inputs for AI question
// generation: the user's learning preferences, weak-area topics, high-priority
// topics, gap analysis and priority distribution.
//
// Every lookup is best effort. A failed lookup is logged as a warning and
// replaced by an empty value; missing preferences (a lookup error or a nil
// result) are replaced by the defaults. The returned value is never nil.
func (w *Worker) getPriorityGenerationData(ctx context.Context, userID int, language, level string, questionType models.QuestionType) *PriorityGenerationData {
	// Get user preferences.
	prefs, err := w.learningService.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		w.logger.Warn(ctx, "Worker failed to get user preferences", map[string]interface{}{
			"instance": w.instance,
			"user_id":  userID,
			"error":    err.Error(),
		})
	}
	// The service may return nil preferences without an error; fall back to
	// the defaults in that case too, because prefs is dereferenced below.
	if err != nil || prefs == nil {
		prefs = w.getDefaultLearningPreferences()
	}

	// Get weak areas.
	weakAreas, err := w.learningService.GetUserWeakAreas(ctx, userID, 5)
	if err != nil {
		w.logger.Warn(ctx, "Worker failed to get weak areas", map[string]interface{}{
			"instance": w.instance,
			"user_id":  userID,
			"error":    err.Error(),
		})
		weakAreas = []map[string]interface{}{}
	}

	// Convert weak areas to topic strings, skipping records whose "topic" is
	// missing, empty or not a string.
	var weakAreaTopics []string
	for _, area := range weakAreas {
		if topic, ok := area["topic"].(string); ok && topic != "" {
			weakAreaTopics = append(weakAreaTopics, topic)
		}
	}

	// Get high priority topics.
	highPriorityTopics, err := w.getHighPriorityTopics(ctx, userID, language, level, questionType)
	if err != nil {
		w.logger.Warn(ctx, "Worker failed to get high priority topics", map[string]interface{}{
			"instance": w.instance,
			"user_id":  userID,
			"error":    err.Error(),
		})
		highPriorityTopics = []string{}
	}

	// Get gap analysis.
	gapAnalysis, err := w.getGapAnalysis(ctx, userID, language, level, questionType)
	if err != nil {
		w.logger.Warn(ctx, "Worker failed to get gap analysis", map[string]interface{}{
			"instance": w.instance,
			"user_id":  userID,
			"error":    err.Error(),
		})
		gapAnalysis = map[string]int{}
	}

	// Get priority distribution.
	priorityDistribution, err := w.getPriorityDistribution(ctx, userID, language, level, questionType)
	if err != nil {
		w.logger.Warn(ctx, "Worker failed to get priority distribution", map[string]interface{}{
			"instance": w.instance,
			"user_id":  userID,
			"error":    err.Error(),
		})
		priorityDistribution = map[string]int{}
	}

	// Focus on weak areas only when there is something to focus on and the
	// user's preferences ask for it.
	focusOnWeakAreas := len(weakAreaTopics) > 0 && prefs.FocusOnWeakAreas

	return &PriorityGenerationData{
		UserWeakAreas:        weakAreaTopics,
		HighPriorityTopics:   highPriorityTopics,
		GapAnalysis:          gapAnalysis,
		UserPreferences:      prefs,
		PriorityDistribution: priorityDistribution,
		FocusOnWeakAreas:     focusOnWeakAreas,
		FreshQuestionRatio:   prefs.FreshQuestionRatio,
	}
}
// getDefaultLearningPreferences returns a fresh preferences value with the
// defaults: weak-area focus off, a fresh-question ratio of 0.3 and a weak-area
// boost of 1.5.
func (w *Worker) getDefaultLearningPreferences() *models.UserLearningPreferences {
	defaults := new(models.UserLearningPreferences)
	defaults.FocusOnWeakAreas = false
	defaults.FreshQuestionRatio = 0.3
	defaults.WeakAreaBoost = 1.5
	return defaults
}
// getHighPriorityTopics forwards to the worker service's high-priority topic
// lookup, passing the question type as a plain string.
func (w *Worker) getHighPriorityTopics(ctx context.Context, userID int, language, level string, questionType models.QuestionType) (result0 []string, err error) {
	typeName := string(questionType)
	result0, err = w.workerService.GetHighPriorityTopics(ctx, userID, language, level, typeName)
	return result0, err
}
// getGapAnalysis forwards to the worker service's gap analysis, passing the
// question type as a plain string.
func (w *Worker) getGapAnalysis(ctx context.Context, userID int, language, level string, questionType models.QuestionType) (result0 map[string]int, err error) {
	typeName := string(questionType)
	result0, err = w.workerService.GetGapAnalysis(ctx, userID, language, level, typeName)
	return result0, err
}
// getPriorityDistribution forwards to the worker service's priority
// distribution lookup, passing the question type as a plain string.
func (w *Worker) getPriorityDistribution(ctx context.Context, userID int, language, level string, questionType models.QuestionType) (result0 map[string]int, err error) {
	typeName := string(questionType)
	result0, err = w.workerService.GetPriorityDistribution(ctx, userID, language, level, typeName)
	return result0, err
}